Python TensorFlow 将图形保存到文件中/从文件加载图形

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/38947658/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-19 21:39:38  来源:igfitidea点击:

TensorFlow saving into/loading a graph from a file

pythontensorflowprotocol-buffers

提问by Technicolor

From what I've gathered so far, there are several different ways of dumping a TensorFlow graph into a file and then loading it into another program, but I haven't been able to find clear examples/information on how they work. What I already know is this:

从我到目前为止收集的信息来看,有几种不同的方法可以将 TensorFlow 图转储到文件中,然后将其加载到另一个程序中,但我一直无法找到有关它们如何工作的明确示例/信息。我已经知道的是:

  1. Save the model's variables into a checkpoint file (.ckpt) using a tf.train.Saver()and restore them later (source)
  2. Save a model into a .pb file and load it back in using tf.train.write_graph()and tf.import_graph_def()(source)
  3. Load in a model from a .pb file, retrain it, and dump it into a new .pb file using Bazel (source)
  4. Freeze the graph to save the graph and weights together (source)
  5. Use as_graph_def()to save the model, and for weights/variables, map them into constants (source)
  1. 使用 a 将模型的变量保存到检查点文件 (.ckpt) 中tf.train.Saver()并稍后恢复它们 ( source)
  2. 将模型保存到 .pb 文件中,然后使用tf.train.write_graph()tf.import_graph_def())将其加载回
  3. 从 .pb 文件加载模型,重新训练它,然后使用 Bazel 将其转储到新的 .pb 文件中(来源
  4. 冻结图形以将图形和权重保存在一起(来源
  5. 使用as_graph_def()保存模型,并为权重/变量,它们映射到常数(

However, I haven't been able to clear up several questions regarding these different methods:

但是,我一直无法澄清有关这些不同方法的几个问题:

  1. Regarding checkpoint files, do they only save the trained weights of a model? Could checkpoint files be loaded into a new program, and be used to run the model, or do they simply serve as ways to save the weights in a model at a certain time/stage?
  2. Regarding tf.train.write_graph(), are the weights/variables saved as well?
  3. Regarding Bazel, can it only save into/load from .pb files for retraining? Is there a simple Bazel command just to dump a graph into a .pb?
  4. Regarding freezing, can a frozen graph be loaded in using tf.import_graph_def()?
  5. The Android demo for TensorFlow loads in Google's Inception model from a .pb file. If I wanted to substitute my own .pb file, how would I go about doing that? Would I need to change any native code/methods?
  6. In general, what exactly is the difference between all these methods? Or more broadly, what is the difference between as_graph_def()/.ckpt/.pb?
  1. 关于检查点文件,它们是否只保存模型的训练权重?检查点文件是否可以加载到新程序中,并用于运行模型,或者它们只是作为在特定时间/阶段保存模型中权重的方法?
  2. 关于tf.train.write_graph(),权重/变量是否也保存了?
  3. 关于 Bazel,它只能保存到/从 .pb 文件加载以进行再培训吗?是否有一个简单的 Bazel 命令只是将图形转储到 .pb 中?
  4. 关于冻结,可以在使用中加载冻结图tf.import_graph_def()吗?
  5. TensorFlow 的 Android 演示从 .pb 文件加载到 Google 的 Inception 模型中。如果我想替换我自己的 .pb 文件,我该怎么做?我需要更改任何本机代码/方法吗?
  6. 一般来说,所有这些方法之间究竟有什么区别?或者更广泛地说,/. as_graph_def()ckpt/.pb之间有什么区别?

In short, what I'm looking for is a method to save both a graph (as in, the various operations and such) and its weights/variables into a file, which can then be used to load the graph and weights into another program, for use (not necessarily continuing/retraining).

简而言之,我正在寻找的是一种将图形(例如各种操作等)及其权重/变量保存到文件中的方法,然后可以使用该文件将图形和权重加载到另一个程序中,供使用(不一定要继续/再培训)。

Documentation about this topic isn't very straightforward, so any answers/information would be greatly appreciated.

有关此主题的文档不是很简单,因此将不胜感激任何答案/信息。

采纳答案by mrry

There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:

有很多方法可以解决在 TensorFlow 中保存模型的问题,这可能会让人有点困惑。依次回答每个子问题:

  1. The checkpoint files (produced e.g. by calling saver.save()on a tf.train.Saverobject) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights. Note that calling saver.save()also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorialfor more details.

  2. tf.train.write_graph()only writes the graph structure; not the weights.

  3. Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)

  4. A frozen graph can be loaded using tf.import_graph_def(). In this case, the weights are (typically) embedded in the graph, so you don't need to load a separate checkpoint.

  5. The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the inputNameand outputNamestrings that are passed to TensorFlowClassifier.initializeTensorFlow().

  6. The GraphDefis the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as the MetaGraphDeflibraries, Keras, and skflowbuild on these mechanisms to provide more convenient ways to save and restore an entire model.

  1. 检查点文件(例如产生通过调用saver.save()一个上tf.train.Saver对象)只包含的权重,并且在相同程序中定义的任何其它变量。要在另一个程序中使用它们,您必须重新创建关联的图结构(例如,通过运行代码再次构建它,或调用tf.import_graph_def()),它会告诉 TensorFlow 如何处理这些权重。请注意,调用saver.save()还会生成一个包含 的文件MetaGraphDef,其中包含一个图表以及如何将来自检查点的权重与该图表相关联的详细信息。有关更多详细信息,请参阅教程

  2. tf.train.write_graph()只写图结构;不是权重。

  3. Bazel 与读取或写入 TensorFlow 图无关。(也许我误解了您的问题:请随时在评论中澄清。)

  4. 可以使用 加载冻结图tf.import_graph_def()。在这种情况下,权重(通常)嵌入在图中,因此您无需加载单独的检查点。

  5. 主要的变化是更新输入模型的张量的名称,以及从模型中获取的张量的名称。在 TensorFlow Android 演示中,这将对应于传递给的inputNameoutputName字符串TensorFlowClassifier.initializeTensorFlow()

  6. GraphDef是程序结构,它通常不会在训练过程中改变。检查点是训练过程状态的快照,通常在训练过程的每一步都会发生变化。因此,TensorFlow 对这些类型的数据使用不同的存储格式,低级 API 提供了不同的方式来保存和加载它们。更高级别的库,如MetaGraphDef图书馆,Kerasskflow对这些机制的构建提供更加便捷的方式来保存和恢复整个模型。

回答by Srihari Humbarwadi

You can try the following code:

您可以尝试以下代码:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)