GraphKeys.TRAINABLE_VARIABLES
与tf.trainable_variables()
相同吗?
GraphKeys.TRAINABLE_VARIABLES
实际上是tf.GraphKeys.TRAINABLE_VARIABLES
吗?
看起来像网络成功地训练:
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self.train_op = optimizer.minimize(self.loss, var_list=tf.trainable_variables())
但不使用
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self.train_op = optimizer.minimize(self.loss)
var_list: Optional list or tuple of Variable objects to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.
正如我在批处理规范化示例中看到的那样,省略了代码var_list
:
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
答案 0 :(得分:1)
如果不将var_list
传递给minimize()
函数,则将按以下方式检索变量(取自compute_gradients()
source code):
if var_list is None:
var_list = (
variables.trainable_variables() +
ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
如果您还没有定义ResourceVariable
中没有的任何tf.trainable_variables()
实例,结果应该是相同的。我的猜测是问题出在其他地方。
您可以在调用minimize()
之前尝试进行一些测试,以确保没有ResourceVariable
以外的tf.trainable_variables()
:
import tensorflow as tf
with tf.Graph().as_default():
x = tf.placeholder(tf.float32, shape=[None, 2])
with tf.name_scope('network'):
logits = tf.layers.dense(x, units=2)
var_list = (tf.trainable_variables()
+ tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
assert set(var_list) == set(tf.trainable_variables())