如何从python中的.pb文件恢复Tensorflow模型?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/50632258/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
How to restore Tensorflow model from .pb file in python?
提问by vizsatiz
I have an tensorflow .pb file which I would like to load into python DNN, restore the graph and get the predictions. I am doing this to test out whether the .pb file created can make the predictions similar to the normal Saver.save() model.
我有一个 tensorflow .pb 文件,我想将其加载到 python DNN 中,恢复图形并获得预测。我这样做是为了测试创建的 .pb 文件是否可以做出类似于普通 Saver.save() 模型的预测。
My basic problem is am getting a very different value of predictions when I make them on Android using the above mentioned .pb file
我的基本问题是,当我使用上述 .pb 文件在 Android 上进行预测时,得到的预测值非常不同
My .pb file creation code:
我的 .pb 文件创建代码:
frozen_graph = tf.graph_util.convert_variables_to_constants(
session,
session.graph_def,
['outputLayer/Softmax']
)
with open('frozen_model.pb', 'wb') as f:
f.write(frozen_graph.SerializeToString())
So I have two major concerns:
所以我有两个主要的担忧:
- How can I load the above mentioned .pb file to python Tensorflow model ?
- Why am I getting completely different values of prediction in python and android ?
- 如何将上述 .pb 文件加载到 python Tensorflow 模型?
- 为什么我在 python 和 android 中得到完全不同的预测值?
回答by sahu
The following code will read the model and print out the names of the nodes in the graph.
以下代码将读取模型并打印出图中节点的名称。
import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
print("load graph")
with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
graph_nodes=[n for n in graph_def.node]
names = []
for t in graph_nodes:
names.append(t.name)
print(names)
You are freezing the graph properly that is why you are getting different results basically weights are not getting stored in your model. You can use the freeze_graph.py(link) for getting a correctly stored graph.
您正确地冻结了图形,这就是为什么您得到不同结果的原因基本上权重没有存储在您的模型中。您可以使用freeze_graph.py(链接)获取正确存储的图形。
回答by caylus
Here is the updated code for tensorflow 2.
这是 tensorflow 2 的更新代码。
import tensorflow as tf
GRAPH_PB_PATH = './frozen_model.pb'
with tf.compat.v1.Session() as sess:
print("load graph")
with tf.io.gfile.GFile(GRAPH_PB_PATH,'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
graph_nodes=[n for n in graph_def.node]
names = []
for t in graph_nodes:
names.append(t.name)
print(names)