Python Tensorflow variable scope: reuse if variable exists
Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must do so under the same CC BY-SA license, cite the original URL, and attribute it to the original authors (not me): StackOverflow
Original URL: http://stackoverflow.com/questions/38545362/
Tensorflow variable scope: reuse if variable exists
Asked by holdenlee
I want a piece of code that creates a variable within a scope if it doesn't exist, and accesses the variable if it already exists. I need it to be the same code since it will be called multiple times.
However, Tensorflow needs me to specify whether I want to create or reuse the variable, like this:
import tensorflow as tf

with tf.variable_scope("foo"):  # create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True):  # reuse the second time
    v = tf.get_variable("v", [1])
How can I get it to figure out whether to create or reuse the variable automatically? I.e., I want the above two blocks of code to be the same and have the program run.
Accepted answer by rvinas
A ValueError is raised in get_variable() when creating a new variable whose shape is not declared, or when violating reuse during variable creation. Therefore, you can try this:
def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            # First call: create the variable.
            v = tf.get_variable(var, shape)
        except ValueError:
            # Variable already exists: switch the scope to reuse and fetch it.
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2
Note that the following also works:
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2
UPDATE. The new API now supports auto-reuse:
def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v
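A minimal usage sketch of this AUTO_REUSE helper (assuming a TF 1.x graph, where reusing a variable hands back the same underlying Variable object):

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 is v2  # both calls resolve to the same variable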
Answered by Zhongyu Kuang
Although using a "try...except..." clause works, I think a more elegant and maintainable way is to separate the variable initialization process from the "reuse" process.
def initialize_variable(scope_name, var_name, shape):
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variables()

def get_scope_variable(scope_name, var_name):
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v
Since we often only need to initialize a variable once but reuse/share it many times, separating the two processes makes the code cleaner. This way, we also don't need to go through the "try" clause every time to check whether the variable has already been created.
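A minimal usage sketch of this split, assuming the variable is initialized exactly once before any lookup:

initialize_variable('foo', 'v', [1])
v1 = get_scope_variable('foo', 'v')
v2 = get_scope_variable('foo', 'v')
assert v1 is v2  # both lookups return the shared variable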
Answered by Mikhail Mishin
The new AUTO_REUSE option does the trick.
From the tf.variable_scope API docs: if reuse=tf.AUTO_REUSE, we create variables if they do not exist, and return them otherwise.
Basic example of sharing a variable with AUTO_REUSE:
def foo():
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
        v = tf.get_variable("v", [1])
    return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2
Answered by AlexP
We can write our own abstraction over tf.variable_scope that uses reuse=None on the first call and reuse=True on subsequent calls:
def variable_scope(name_or_scope, *args, **kwargs):
    # Resolve a unique key for this scope.
    if isinstance(name_or_scope, str):
        scope_name = tf.get_variable_scope().name + '/' + name_or_scope
    elif isinstance(name_or_scope, tf.Variable):
        scope_name = name_or_scope.name
    # Reuse if this scope has been entered before; otherwise record it.
    if scope_name in variable_scope.scopes:
        kwargs['reuse'] = True
    else:
        variable_scope.scopes.add(scope_name)
    return tf.variable_scope(name_or_scope, *args, **kwargs)

variable_scope.scopes = set()
Usage:
with variable_scope("foo"): #create the first time
v = tf.get_variable("v", [1])
with variable_scope("foo"): #reuse the second time
v = tf.get_variable("v", [1])
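As a quick sanity check (a sketch assuming a fresh TF 1.x graph; the scope "bar" and variable "w" are just placeholder names), both blocks hand back the same underlying variable:

with variable_scope("bar"):
    a = tf.get_variable("w", [1])
with variable_scope("bar"):
    b = tf.get_variable("w", [1])
assert a is b  # the second block reused the variable created by the first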