如何在嵌套作用域中使用tf.get_collection()作用域过滤

时间:2018-03-21 14:29:08

标签: python tensorflow

我尝试通过在范围内定义变量并使用tf.get_collection()中的范围过滤来检索一组变量:

with tf.variable_scope('inner'):
    v = tf.get_variable(name='foo', shape=[1])
    ...
    # more variables
    ...

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'inner')
# do stuff with variables

这通常可以正常工作,但有时我的代码会被一个模块调用,该模块已经定义了自己的范围get_collection()不再找到变量:

with tf.variable_scope('outer'):
    with tf.variable_scope('inner'):
        v = tf.get_variable(name='foo', shape=[1])
        ...
        # more variables
        ...

我认为过滤是一个正则表达式,因为我可以通过在我的作用域搜索词前加.*来使get_collection()工作,但这有点hacky。有没有更好的方法来解决这个问题?

2 个答案:

答案 0 :(得分:0)

当我想训练模型时,我使用get_collection(),但在此之前,我必须恢复模型数据,如下面的代码所示:

with tf.Session() as sess:
    last_check = tf.train.latest_checkpoint(tf_data)
    saver = tf.train.import_meta_graph(last_check + '.meta')
    print (last_check +'.meta')
    saver.restore(sess, last_check)
    ######
    Model_variables = tf.GraphKeys.MODEL_VARIABLES
    Global_Variables = tf.GraphKeys.GLOBAL_VARIABLES
    ######
    all_vars = tf.get_collection(Model_variables)
    # print (all_vars)
    pesos=[]
    for i in all_vars:
        print (str(i) + '  -->  '+ str(i.eval()))

答案 1 :(得分:0)

tf.get_collcetion(key,scope =“ outer / inner”)