GraphKeys.TRAINABLE_VARIABLES vs. tf.trainable_variables()

Date: 2019-04-10 18:24:26

Tags: python tensorflow

Is GraphKeys.TRAINABLE_VARIABLES the same as tf.trainable_variables()?

Is GraphKeys.TRAINABLE_VARIABLES actually tf.GraphKeys.TRAINABLE_VARIABLES?

The network appears to train successfully with:

optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss, var_list=tf.trainable_variables())

but not with:

optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss)

According to the documentation:

var_list: Optional list or tuple of Variable objects to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

As I can see in the batch normalization example, var_list is omitted there as well:

  x_norm = tf.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

1 Answer:

Answer 0 (score: 1)

If you do not pass var_list to the minimize() function, the variables are retrieved as follows (taken from the compute_gradients() source code):

if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

If you have not defined any ResourceVariable instances that are missing from tf.trainable_variables(), the result should be identical. My guess is that the problem lies somewhere else.
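The collection-fallback logic above can be illustrated without TensorFlow at all. The following is a minimal plain-Python sketch: the Graph class, the collection keys, and the helper functions are simplified stand-ins mirroring the TF 1.x names, not the real implementation.

```python
# Simplified stand-ins for the TF 1.x collection keys.
TRAINABLE_VARIABLES = "trainable_variables"
TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

class Graph:
    """Toy version of a TF graph's named collections."""
    def __init__(self):
        self.collections = {}

    def add_to_collection(self, key, value):
        self.collections.setdefault(key, []).append(value)

    def get_collection(self, key):
        return list(self.collections.get(key, []))

def trainable_variables(graph):
    # What tf.trainable_variables() returns: the TRAINABLE_VARIABLES collection.
    return graph.get_collection(TRAINABLE_VARIABLES)

def default_var_list(graph):
    # What compute_gradients() falls back to when var_list is None.
    return (trainable_variables(graph)
            + graph.get_collection(TRAINABLE_RESOURCE_VARIABLES))

g = Graph()
g.add_to_collection(TRAINABLE_VARIABLES, "w")  # ordinary trainable variables
g.add_to_collection(TRAINABLE_VARIABLES, "b")

# With no extra resource variables, the default equals tf.trainable_variables():
assert default_var_list(g) == trainable_variables(g)

# A variable registered only under TRAINABLE_RESOURCE_VARIABLES would make
# minimize()'s default var_list strictly larger:
g.add_to_collection(TRAINABLE_RESOURCE_VARIABLES, "w_res")
assert set(default_var_list(g)) != set(trainable_variables(g))
```

The point of the sketch: the two training setups in the question can only differ if something lives in the TRAINABLE_RESOURCE_VARIABLES collection but not in tf.trainable_variables().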

Before calling minimize(), you could run a small test to make sure there are no ResourceVariables outside of tf.trainable_variables():

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2])
    with tf.name_scope('network'):
        logits = tf.layers.dense(x, units=2)

    # The same default var_list that compute_gradients() would build.
    var_list = (tf.trainable_variables()
                + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    # If this holds, the resource-variable collection adds nothing extra.
    assert set(var_list) == set(tf.trainable_variables())