Python TensorFlow 将图形保存到文件中/从文件加载图形
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/38947658/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
TensorFlow saving into/loading a graph from a file
提问by Technicolor
From what I've gathered so far, there are several different ways of dumping a TensorFlow graph into a file and then loading it into another program, but I haven't been able to find clear examples/information on how they work. What I already know is this:
从我到目前为止收集的信息来看,有几种不同的方法可以将 TensorFlow 图转储到文件中,然后将其加载到另一个程序中,但我一直无法找到有关它们如何工作的明确示例/信息。我已经知道的是:
- Save the model's variables into a checkpoint file (.ckpt) using a
tf.train.Saver()
and restore them later (source) - Save a model into a .pb file and load it back in using
tf.train.write_graph()
andtf.import_graph_def()
(source) - Load in a model from a .pb file, retrain it, and dump it into a new .pb file using Bazel (source)
- Freeze the graph to save the graph and weights together (source)
- Use
as_graph_def()
to save the model, and for weights/variables, map them into constants (source)
- 使用 a 将模型的变量保存到检查点文件 (.ckpt) 中
tf.train.Saver()
并稍后恢复它们 ( source) - 将模型保存到 .pb 文件中,然后使用
tf.train.write_graph()
和tf.import_graph_def()
(源)将其加载回 - 从 .pb 文件加载模型,重新训练它,然后使用 Bazel 将其转储到新的 .pb 文件中(来源)
- 冻结图形以将图形和权重保存在一起(来源)
- 使用
as_graph_def()
保存模型,并为权重/变量,它们映射到常数(源)
However, I haven't been able to clear up several questions regarding these different methods:
但是,我一直无法澄清有关这些不同方法的几个问题:
- Regarding checkpoint files, do they only save the trained weights of a model? Could checkpoint files be loaded into a new program, and be used to run the model, or do they simply serve as ways to save the weights in a model at a certain time/stage?
- Regarding
tf.train.write_graph()
, are the weights/variables saved as well? - Regarding Bazel, can it only save into/load from .pb files for retraining? Is there a simple Bazel command just to dump a graph into a .pb?
- Regarding freezing, can a frozen graph be loaded in using
tf.import_graph_def()
? - The Android demo for TensorFlow loads in Google's Inception model from a .pb file. If I wanted to substitute my own .pb file, how would I go about doing that? Would I need to change any native code/methods?
- In general, what exactly is the difference between all these methods? Or more broadly, what is the difference between
as_graph_def()
/.ckpt/.pb?
- 关于检查点文件,它们是否只保存模型的训练权重?检查点文件是否可以加载到新程序中,并用于运行模型,或者它们只是作为在特定时间/阶段保存模型中权重的方法?
- 关于
tf.train.write_graph()
,权重/变量是否也保存了? - 关于 Bazel,它只能保存到/从 .pb 文件加载以进行再培训吗?是否有一个简单的 Bazel 命令只是将图形转储到 .pb 中?
- 关于冻结,可以在使用中加载冻结图
tf.import_graph_def()
吗? - TensorFlow 的 Android 演示从 .pb 文件加载到 Google 的 Inception 模型中。如果我想替换我自己的 .pb 文件,我该怎么做?我需要更改任何本机代码/方法吗?
- 一般来说,所有这些方法之间究竟有什么区别?或者更广泛地说,/.
as_graph_def()
ckpt/.pb之间有什么区别?
In short, what I'm looking for is a method to save both a graph (as in, the various operations and such) and its weights/variables into a file, which can then be used to load the graph and weights into another program, for use (not necessarily continuing/retraining).
简而言之,我正在寻找的是一种将图形(例如各种操作等)及其权重/变量保存到文件中的方法,然后可以使用该文件将图形和权重加载到另一个程序中,供使用(不一定要继续/再培训)。
Documentation about this topic isn't very straightforward, so any answers/information would be greatly appreciated.
有关此主题的文档不是很简单,因此将不胜感激任何答案/信息。
采纳答案by mrry
There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:
有很多方法可以解决在 TensorFlow 中保存模型的问题,这可能会让人有点困惑。依次回答每个子问题:
The checkpoint files (produced e.g. by calling
saver.save()
on atf.train.Saver
object) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or callingtf.import_graph_def()
), which tells TensorFlow what to do with those weights. Note that callingsaver.save()
also produces a file containing aMetaGraphDef
, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorialfor more details.tf.train.write_graph()
only writes the graph structure; not the weights.Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)
A frozen graph can be loaded using
tf.import_graph_def()
. In this case, the weights are (typically) embedded in the graph, so you don't need to load a separate checkpoint.The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the
inputName
andoutputName
strings that are passed toTensorFlowClassifier.initializeTensorFlow()
.The
GraphDef
is the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as theMetaGraphDef
libraries, Keras, and skflowbuild on these mechanisms to provide more convenient ways to save and restore an entire model.
检查点文件(例如产生通过调用
saver.save()
一个上tf.train.Saver
对象)只包含的权重,并且在相同程序中定义的任何其它变量。要在另一个程序中使用它们,您必须重新创建关联的图结构(例如,通过运行代码再次构建它,或调用tf.import_graph_def()
),它会告诉 TensorFlow 如何处理这些权重。请注意,调用saver.save()
还会生成一个包含 的文件MetaGraphDef
,其中包含一个图表以及如何将来自检查点的权重与该图表相关联的详细信息。有关更多详细信息,请参阅教程。tf.train.write_graph()
只写图结构;不是权重。Bazel 与读取或写入 TensorFlow 图无关。(也许我误解了您的问题:请随时在评论中澄清。)
可以使用 加载冻结图
tf.import_graph_def()
。在这种情况下,权重(通常)嵌入在图中,因此您无需加载单独的检查点。主要的变化是更新输入模型的张量的名称,以及从模型中获取的张量的名称。在 TensorFlow Android 演示中,这将对应于传递给的
inputName
和outputName
字符串TensorFlowClassifier.initializeTensorFlow()
。这
GraphDef
是程序结构,它通常不会在训练过程中改变。检查点是训练过程状态的快照,通常在训练过程的每一步都会发生变化。因此,TensorFlow 对这些类型的数据使用不同的存储格式,低级 API 提供了不同的方式来保存和加载它们。更高级别的库,如MetaGraphDef
图书馆,Keras和skflow对这些机制的构建提供更加便捷的方式来保存和恢复整个模型。
回答by Srihari Humbarwadi
You can try the following code:
您可以尝试以下代码:
with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)