如何获得TRAINABLE VARIABLES,其名称包含在TensorFlow中重新训练的特定字符?

时间:2017-05-26 14:05:29

标签: tensorflow

在TensorFlow中,我们可以使用tf.get_collection来获取具有特定前缀的变量。但是,如何才能获得名称中包含特定字符的变量,以便执行重新训练等任务?

玩具示例代码

import tensorflow as tf
with tf.variable_scope('net'):
    var_1 = tf.Variable(tf.random_normal([3, 5],stddev=0.35),name='var1')
    with tf.variable_scope('retrain'):
        var_2 = tf.Variable(tf.random_normal([3, 5], stddev=0.35),name='var2')
        var_3 = tf.Variable(tf.zeros([5]), name="var3")

在此示例中,print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,"net"))将返回所有可训练的变量。

但是,print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,"retrain"))将返回空列表[],而不是var_2var_3

假设在实际条件下,精确的根变量范围net可以是netnet_1 ...除了打印所有可训练变量或使用tensorboard之外找到变量名称的前缀(以确定它是net\retrain还是net_1\retrain),我们可以使用tf.get_collection之类的函数来获取{{1 }和var_2

2 个答案:

答案 0 :(得分:0)

范围是嵌套的。你想用

tf.get_collection(tf.GraphKeys.VARIABLES, "net/retrain")

答案 1 :(得分:0)

您可以在范围retrain内获取变量。

tf.get_collection(tf.GraphKeys.VARIABLES, scope='net/retrain').