Key <variable_name> not found in checkpoint Tensorflow
Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use or share it, but you must do so under the same CC BY-SA license, note the original URL and author information, and attribute it to the original authors (not me): StackOverflow
Original question: http://stackoverflow.com/questions/45179556/
Asked by YellowPillow
I'm using Tensorflow v1.1 and I've been trying to figure out how to use my EMA'ed weights for inference, but no matter what I do I keep getting the error
Not found: Key W/ExponentialMovingAverage not found in checkpoint
even though when I loop through and print out all the tf.global_variables, the key exists.
Here is a reproducible script heavily adapted from Facenet's unit test:
import os
import tensorflow as tf
import numpy as np

tf.reset_default_graph()

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
opt_op = optimizer.minimize(loss)

# Track the moving averages of all trainable variables.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
variables = tf.trainable_variables()
print(variables)
averages_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([opt_op]):
    train_op = tf.group(averages_op)

# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()
saver = tf.train.Saver(tf.trainable_variables())

# Launch the graph.
sess = tf.Session()
sess.run(init)

# Fit the line.
for _ in range(201):
    sess.run(train_op)

w_reference = sess.run('W/ExponentialMovingAverage:0')
b_reference = sess.run('b/ExponentialMovingAverage:0')

saver.save(sess, os.path.join("model_ex1"))
tf.reset_default_graph()
tf.train.import_meta_graph("model_ex1.meta")
sess = tf.Session()

print('------------------------------------------------------')
for var in tf.global_variables():
    print('all variables: ' + var.op.name)
for var in tf.trainable_variables():
    print('normal variable: ' + var.op.name)
for var in tf.moving_average_variables():
    print('ema variable: ' + var.op.name)
print('------------------------------------------------------')

mode = 1
restore_vars = {}
if mode == 0:
    ema = tf.train.ExponentialMovingAverage(1.0)
    for var in tf.trainable_variables():
        print('%s: %s' % (ema.average_name(var), var.op.name))
        restore_vars[ema.average_name(var)] = var
elif mode == 1:
    for var in tf.trainable_variables():
        ema_name = var.op.name + '/ExponentialMovingAverage'
        print('%s: %s' % (ema_name, var.op.name))
        restore_vars[ema_name] = var

saver = tf.train.Saver(restore_vars, name='ema_restore')
saver.restore(sess, os.path.join("model_ex1"))  # error happens here!

w_restored = sess.run('W:0')
b_restored = sess.run('b:0')

print(w_reference)
print(w_restored)
print(b_reference)
print(b_restored)
Answered by Alexandre Passos
The "key not found in checkpoint" error means that the variable exists in your model in memory but not in the serialized checkpoint file on disk.
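For illustration, a minimal sketch tied to the repro script above (an observation about that script, not code from this answer): its saver is built from tf.trainable_variables(), which leaves out the shadow variables created by ema.apply(), so they never reach disk. Building the saver over all global variables writes them:

# Sketch: save every variable, including the EMA shadow copies created by
# ema.apply(), so that 'W/ExponentialMovingAverage' is written to disk.
saver = tf.train.Saver(tf.global_variables())  # instead of tf.trainable_variables()
saver.save(sess, "model_ex1")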
You should use the inspect_checkpoint tool to understand what tensors are being saved in your checkpoint, and why some exponential moving averages are not being saved here.
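For example, a minimal sketch using the inspect_checkpoint module that ships with TF 1.x (the checkpoint prefix "model_ex1" is the one from the repro above):

from tensorflow.python.tools import inspect_checkpoint as chkp

# Print every tensor stored in the checkpoint; if 'W/ExponentialMovingAverage'
# does not appear in this listing, the restore has to fail.
chkp.print_tensors_in_checkpoint_file("model_ex1", tensor_name='', all_tensors=True)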
It's not clear from your repro example which line is supposed to trigger the error.
Answered by Lerner Zhang
I'd like to add a way to make the best use of the trained variables that are in the checkpoint.
Keep in mind that all variables in the saver's var_list should be contained in the checkpoint you configured. You can check the variables in the saver by:
print(restore_vars)
and those variables in the checkpoint by:
vars_in_checkpoint = tf.train.list_variables(os.path.join("model_ex1"))
in your case.
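For example, a small sketch of that containment check (tf.train.list_variables returns (name, shape) pairs, and restore_vars is the dict built in the repro above):

# Collect the names actually stored in the checkpoint and report any key
# in the saver's var_list that the checkpoint does not contain.
ckpt_names = {name for name, shape in vars_in_checkpoint}
missing = [name for name in restore_vars if name not in ckpt_names]
print('Keys missing from the checkpoint:', missing)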
If the restore_vars are all included in vars_in_checkpoint, then restoring will not raise the error; otherwise, initialize all variables first:
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
sess.run(tf.variables_initializer(all_variables))
All variables will be initialized, whether they are in the checkpoint or not. Then you can filter out the variables in restore_vars that are not included in the checkpoint (suppose all variables with ExponentialMovingAverage in their names are not in the checkpoint):
temp_saver = tf.train.Saver(
    var_list=[v for v in all_variables if "ExponentialMovingAverage" not in v.name])
ckpt_state = tf.train.get_checkpoint_state(os.path.join("model_ex1"))
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
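A possible follow-up, sketched under the same assumptions (the prefix "model_ex1_full" is a hypothetical name): once the partial restore succeeds, saving a fresh checkpoint that covers every variable means later restores, EMA shadows included, no longer hit the missing-key error.

# Sketch: persist all variables (the restored ones plus the freshly
# initialized EMA shadows) so the next restore finds every key.
full_saver = tf.train.Saver(tf.global_variables())
full_saver.save(sess, "model_ex1_full")  # hypothetical checkpoint prefix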
This may save some time compared to training the model from scratch. (In my scenario the restored variables made no significant improvement over training from scratch at the beginning, since all the old optimizer variables were abandoned. But I think it can accelerate the optimization process significantly, because it is like pretraining some variables.)
Anyway, some variables are useful to restore, such as embeddings, some layers, etc.