Disclaimer: this page is a Chinese/English translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. If you use it, you must also follow CC BY-SA, cite the original URL and authors, and attribute it to the original authors (not me). StackOverflow original: http://stackoverflow.com/questions/38190365/

Date: 2020-08-19 20:28:33 · Source: igfitidea

How does one initialize a variable with tf.get_variable and a numpy value in TensorFlow?

Tags: python, numpy, tensorflow

Asked by Pinocchio

I wanted to initialize some of the variables in my network with numpy values. For the sake of example, consider:

init=np.random.rand(1,2)
tf.get_variable('var_name',initializer=init)

When I do that, I get an error:

ValueError: Shape of a new variable (var_name) must be fully defined, but instead was <unknown>.

Why am I getting that error?

To try to fix it, I tried:

tf.get_variable('var_name',initializer=init, shape=[1,2])

which yielded an even weirder error:

TypeError: 'numpy.ndarray' object is not callable

I tried reading the docs and examples, but it didn't really help.

Is it not possible to initialize variables with numpy arrays using the get_variable method in TensorFlow?

Answer by keveman

The following works:

import numpy as np
import tensorflow as tf
init = tf.constant(np.random.rand(1, 2))
tf.get_variable('var_name', initializer=init)
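
A caveat worth checking with the snippet above: tf.constant takes its dtype from the numpy array, and np.random.rand returns float64, so var_name ends up as a float64 variable. If the rest of the model uses float32, casting the array before wrapping it avoids dtype mismatches later. The numpy side of that is a one-line cast; this is a small illustration, not part of the original answer:

```python
import numpy as np

# np.random.rand always returns float64 values in [0, 1).
init_value = np.random.rand(1, 2)
print(init_value.dtype)

# Cast before wrapping in tf.constant so the variable is created as float32.
init_value32 = init_value.astype(np.float32)
print(init_value32.dtype)
```

The cast array can then be passed to tf.constant exactly as in the answer above.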

The documentation for get_variable is a little lacking indeed. Just for your reference, the initializer argument has to be either a TensorFlow Tensor object (which, in your case, can be constructed by calling tf.constant on a numpy value), or a 'callable' that takes two arguments, shape and dtype: the shape and data type of the value it is supposed to return. Again, in your case, you can write the following if you want to use the 'callable' mechanism:

# A callable initializer: TensorFlow calls it with the variable's shape and dtype.
init = lambda shape, dtype: np.random.rand(*shape)
tf.get_variable('var_name', initializer=init, shape=[1, 2])
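
The lambda above ignores dtype and draws fresh random numbers each time. If you instead want a callable that returns one fixed numpy array, a minimal sketch could look like the following. numpy_initializer is a hypothetical helper, not part of TensorFlow, and the code is plain numpy so it can be tried without TensorFlow installed:

```python
import numpy as np


def numpy_initializer(value):
    """Hypothetical helper (not a TensorFlow API): wrap a fixed numpy array
    in a callable with the (shape, dtype) signature described above."""
    value = np.asarray(value)

    def init(shape, dtype=None, partition_info=None):
        # partition_info is accepted and ignored; this assumes some TF 1.x
        # releases may pass it as an extra keyword argument.
        if tuple(shape) != value.shape:
            raise ValueError("requested shape %s does not match value shape %s"
                             % (tuple(shape), value.shape))
        if dtype is None:
            return value
        # A tf.DType exposes its numpy equivalent as .as_numpy_dtype;
        # a plain numpy dtype is used as-is.
        return value.astype(getattr(dtype, "as_numpy_dtype", dtype))

    return init
```

Assuming TF 1.x graph mode, it would be used as tf.get_variable('var_name', shape=[1, 2], initializer=numpy_initializer(np.random.rand(1, 2))). How TensorFlow invokes the callable (for example, which extra keyword arguments it passes) has varied between releases, so treat the exact signature as an assumption.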

Answer by Nezha

@keveman answered well. As a supplement on the usage of tf.get_variable('var_name', initializer=init): the TensorFlow documentation does give a comprehensive example, using tf.constant_initializer.

import numpy as np
import tensorflow as tf

value = [0, 1, 2, 3, 4, 5, 6, 7]
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
tf.reset_default_graph()
with tf.Session():
    x = tf.get_variable('x', shape=[2, 4], initializer=init)
    x.initializer.run()
    print(x.eval())

# fitting shape:
# [[0.  1.  2.  3.]
#  [4.  5.  6.  7.]]

print('larger shape:')
tf.reset_default_graph()
with tf.Session():
    x = tf.get_variable('x', shape=[3, 4], initializer=init)
    x.initializer.run()
    print(x.eval())

# larger shape:
# [[0.  1.  2.  3.]
#  [4.  5.  6.  7.]
#  [7.  7.  7.  7.]]

print('smaller shape:')
tf.reset_default_graph()
with tf.Session():
    x = tf.get_variable('x', shape=[2, 3], initializer=init)

# smaller shape:
# ValueError: Too many elements provided. Needed at most 6, but received 8

https://www.tensorflow.org/api_docs/python/tf/constant_initializer
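
The three cases above (exact fit, padding with the last value, error when there are too many values) can be mimicked in plain numpy to make the rules explicit. This is a minimal sketch for illustration only: constant_fill is a hypothetical name, not a TensorFlow function, and it assumes the TF 1.x constant_initializer behavior shown in the example output:

```python
import numpy as np


def constant_fill(value, shape):
    """Hypothetical helper mimicking the fill rules of TF 1.x
    tf.constant_initializer as shown above (not a TensorFlow API)."""
    flat = np.asarray(value).ravel()
    needed = int(np.prod(shape))
    if flat.size == 0:
        raise ValueError("value must contain at least one element")
    if flat.size > needed:
        # More values than the variable can hold: reject, as in the
        # "smaller shape" case.
        raise ValueError(
            "Too many elements provided. Needed at most %d, but received %d"
            % (needed, flat.size))
    if flat.size < needed:
        # Fewer values than needed: repeat the last value, as in the
        # "larger shape" case.
        padding = np.full(needed - flat.size, flat[-1], dtype=flat.dtype)
        flat = np.concatenate([flat, padding])
    return flat.reshape(shape)


value = [0, 1, 2, 3, 4, 5, 6, 7]
print(constant_fill(value, [3, 4]))
```

Unlike TensorFlow, this returns a numpy array directly instead of creating a variable, which keeps the fill logic visible on its own.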

Answer by James D

If the variable has already been created (e.g. by some complex function), just use load:

https://www.tensorflow.org/api_docs/python/tf/Variable#load

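import numpy as np
import tensorflow as tf

x_var = tf.Variable(tf.zeros((1, 2), tf.float32))
x_val = np.random.rand(1, 2).astype(np.float32)

sess = tf.Session()
# load() assigns x_val to the existing variable within the given session
x_var.load(x_val, session=sess)

# test
assert np.all(sess.run(x_var) == x_val)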