如何将范围内的所有变量添加到特定集合

时间:2016-10-26 09:01:31

标签: tensorflow

在tensorflow python API中, tf.get_variable 具有参数集合,以将创建的var添加到指定的集合。但是 tf.variable_scope 却没有。 将变量范围内的所有变量添加到某个集合中的建议方法是什么?

4 个答案:

答案 0 :(得分:1)

我不相信有办法直接这样做。您可以在Tensorflow的github问题跟踪器上提交功能请求。

我可以建议您尝试两种解决方法:

  • 迭代tf.all_variables()的结果,并提取名称类似于".../scope_name/..."的变量。范围名称以变量名称编码,由/个字符分隔。

  • 围绕tf.VariableScope和tf.get_variable()编写包装器,它将在范围内创建的变量存储在数据结构中。

我希望有所帮助!

答案 1 :(得分:0)

我设法做到了:

import tensorflow as tf

def var_1():
    with tf.variable_scope("foo") as foo_scope:
        assert foo_scope.name == "ll/foo"
        a = tf.get_variable("a", [2, 2])
    return foo_scope

def var_2(foo_scope):
    with tf.variable_scope("bar"):
        b = tf.get_variable("b", [2, 2])
        with tf.variable_scope("baz") as other_scope:
            c = tf.get_variable("c", [2, 2])
            assert other_scope.name == "ll/bar/baz"
            with tf.variable_scope(foo_scope) as foo_scope2:
                d = tf.get_variable("d", [2, 2])
                assert foo_scope2.name == "ll/foo"  # Not changed.

def main():
    with tf.variable_scope("ll"):
        scp = var_1()
        var_2(scp)
        all_default_global_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        my_collection = tf.get_collection('my_collection') # create my collection
        ll_foo_variables = []
        for variable in all_default_global_variables:
            if "ll/foo" in variable.name:
                ll_foo_variables.append(variable)
        tf.add_to_collection('my_collection', ll_foo_variables)

        variables_in_my_collection = tf.get_collection_ref("my_collection")
        print(variables_in_my_collection)

main()

您可以在a,b,c和d的代码中看到,只有a和d具有相同的范围名称ll/foo

过程:

首先,我添加默认情况下在tf.GraphKeys.GLOBAL_VARIABLES集合中创建的所有变量,然后创建一个名为my_collection的集合,然后我只添加那些带有' ll / foo'在范围名称中为my_collection。

我得到的是我的预期:

[[<tf.Variable 'll/foo/a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'll/foo/d:0' shape=(2, 2) dtype=float32_ref>]]

答案 2 :(得分:0)

import tensorflow as tf    

for var in tf.global_variables(scope='model'):
    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, var)

如果您感兴趣的话,也可以迭代global_variables而不是使用trainable_variables。在这两种情况下,您不仅要捕获使用{{1}手动创建的变量。而且还有由eg创建的那些任何get_variable()电话。

答案 3 :(得分:0)

您可以只获取范围内的所有变量,而不要获取集合:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope')

https://stackoverflow.com/a/36536063/9095840