本地变量默认情况下是否可以训练?

时间:2019-02-20 06:48:58

标签: variables tensorflow

当我浏览指南https://www.tensorflow.org/guide/variables时,我对以下说明(粗体)感到困惑:

  

默认情况下,每个tf.Variable被放置在以下两个位置   集合:

     
      
  • tf.GraphKeys.GLOBAL_VARIABLES ---可以在多个设备之间共享的变量,
  •   
  • tf.GraphKeys.TRAINABLE_VARIABLES-TensorFlow将为其计算梯度的变量。
  •   
     

如果您不希望变量是可训练的,请改为将其添加到tf.GraphKeys.LOCAL_VARIABLES集合中。例如,以下代码段演示了如何向此集合添加名为my_local的变量:

my_local = tf.get_variable("my_local", shape=(), collections [tf.GraphKeys.LOCAL_VARIABLES])`
  

或者,您可以将trainable=False指定为   tf.get_variable

my_non_trainable = tf.get_variable("my_non_trainable", shape=(), trainable=False)

但是当我创建一个局部变量时,它会自动添加到集合tf.GraphKeys.TRAINABLE_VARIABLES中,这意味着它是可训练的。那么,局部变量是否可以训练?

1 个答案:

答案 0 :(得分:1)

文档确实令人困惑。默认情况下,还将局部变量添加到可训练变量的集合中。您可以通过检查tf.trainable_variables()进行检查。因此,看起来局部变量 not 不可训练,仅将其添加到LOCAL_VARIABLES集合中是不够的,但是您需要关键字trainable=False

这是一个简短的脚本,它显示了局部变量和全局变量在训练循环中都得到了更新:

import tensorflow as tf

my_local = tf.get_variable("my_local", shape=(), collections=[tf.GraphKeys.LOCAL_VARIABLES],
                           initializer=tf.constant_initializer(1.0))
my_global = tf.get_variable("my_global", shape=(),
                            initializer=tf.constant_initializer(2.0))

target_value = tf.constant(4.0)
loss = tf.abs(my_local + my_global - target_value)
optim = tf.train.AdamOptimizer(learning_rate=1.0).minimize(loss)

for v in tf.trainable_variables():
    print(v.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    print("local init: ", sess.run(my_local))
    print("global init: ", sess.run(my_global))
    for i in range(2):
        _, l = sess.run([optim, loss])
        print("loss {:.4f}".format(l))
        print("local: ", sess.run(my_local))
        print("global: ", sess.run(my_global))

可打印

my_local:0
my_global:0
local init:  1.0
global init:  2.0
loss 1.0000
local:  1.9999996
global:  2.9999995
loss 1.0000
local:  1.9473683
global:  2.9473681

如果在对my_local的调用中设置trainable=False,则tf.get_variable的值不会改变。