在TensorFlow中,我们可以使用tf.get_collection
来获取具有特定前缀的变量。但是,如何才能获得名称中包含特定字符的变量,以便执行重新训练等任务?
玩具示例代码
import tensorflow as tf
with tf.variable_scope('net'):
var_1 = tf.Variable(tf.random_normal([3, 5],stddev=0.35),name='var1')
with tf.variable_scope('retrain'):
var_2 = tf.Variable(tf.random_normal([3, 5], stddev=0.35),name='var2')
var_3 = tf.Variable(tf.zeros([5]), name="var3")
在此示例中,print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,"net"))
将返回所有可训练的变量。
但是,print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,"retrain"))
将返回空列表[]
,而不是var_2
和var_3
。
假设在实际条件下,精确的根变量范围net
可以是net
,net_1
...除了打印所有可训练变量或使用tensorboard
之外找到变量名称的前缀(以确定它是net\retrain
还是net_1\retrain
),我们可以使用tf.get_collection
之类的函数来获取{{1 }和var_2
?
答案 0 :(得分:0)
范围是嵌套的。你想用
tf.get_collection(tf.GraphKeys.VARIABLES, "net/retrain")
答案 1 :(得分:0)
您可以在范围retrain
内获取变量。
tf.get_collection(tf.GraphKeys.VARIABLES, scope='net/retrain').