在tensorflow python API中, tf.get_variable 具有参数集合,以将创建的var添加到指定的集合。但是 tf.variable_scope 却没有。 将变量范围内的所有变量添加到某个集合中的建议方法是什么?
答案 0 :(得分:1)
我不相信有办法直接这样做。您可以在Tensorflow的github问题跟踪器上提交功能请求。
我可以建议您尝试两种解决方法:
迭代tf.all_variables()
的结果,并提取名称类似于".../scope_name/..."
的变量。范围名称以变量名称编码,由/
个字符分隔。
围绕tf.VariableScope和tf.get_variable()编写包装器,它将在范围内创建的变量存储在数据结构中。
我希望有所帮助!
答案 1 :(得分:0)
我设法做到了:
import tensorflow as tf
def var_1():
with tf.variable_scope("foo") as foo_scope:
assert foo_scope.name == "ll/foo"
a = tf.get_variable("a", [2, 2])
return foo_scope
def var_2(foo_scope):
with tf.variable_scope("bar"):
b = tf.get_variable("b", [2, 2])
with tf.variable_scope("baz") as other_scope:
c = tf.get_variable("c", [2, 2])
assert other_scope.name == "ll/bar/baz"
with tf.variable_scope(foo_scope) as foo_scope2:
d = tf.get_variable("d", [2, 2])
assert foo_scope2.name == "ll/foo" # Not changed.
def main():
with tf.variable_scope("ll"):
scp = var_1()
var_2(scp)
all_default_global_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
my_collection = tf.get_collection('my_collection') # create my collection
ll_foo_variables = []
for variable in all_default_global_variables:
if "ll/foo" in variable.name:
ll_foo_variables.append(variable)
tf.add_to_collection('my_collection', ll_foo_variables)
variables_in_my_collection = tf.get_collection_ref("my_collection")
print(variables_in_my_collection)
main()
您可以在a,b,c和d的代码中看到,只有a和d具有相同的范围名称ll/foo
。
首先,我添加默认情况下在tf.GraphKeys.GLOBAL_VARIABLES集合中创建的所有变量,然后创建一个名为my_collection的集合,然后我只添加那些带有' ll / foo'在范围名称中为my_collection。
我得到的是我的预期:
[[<tf.Variable 'll/foo/a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'll/foo/d:0' shape=(2, 2) dtype=float32_ref>]]
答案 2 :(得分:0)
import tensorflow as tf
for var in tf.global_variables(scope='model'):
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, var)
如果您感兴趣的话,也可以迭代global_variables
而不是使用trainable_variables
。在这两种情况下,您不仅要捕获使用{{1}手动创建的变量。而且还有由eg创建的那些任何get_variable()
电话。
答案 3 :(得分:0)
您可以只获取范围内的所有变量,而不要获取集合:
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope')