tf.get_collection提取一个范围的变量

时间:2017-02-06 16:59:29

标签: python tensorflow

我有n(例如:n = 3)范围,x(例如:x = 4)没有在每个范围内定义的变量。 范围是:

model/generator_0
model/generator_1
model/generator_2

一旦我计算了损失,我想在运行时根据标准从一个范围中提取并提供所有变量。因此,我选择的范围idx的索引是转换为int32

的argmin张量
<tf.Tensor 'model/Cast:0' shape=() dtype=int32>

我已经尝试过了:

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string)) 

显然不起作用。 有没有办法让所有属于该特定范围的x变量使用idx传递到优化器。

提前致谢!

Vignesh Srinivasan

1 个答案:

答案 0 :(得分:4)

你可以在TF 1.0 rc1或更高版本中执行类似的操作:

v = tf.Variable(tf.ones(()))
loss = tf.identity(v)
with tf.variable_scope('adamoptim') as vs:
   optim = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
print([v.name for v in optim_vars]) #=> prints lists of vars created