我正在尝试实现Batch Normalization操作的一个小调整版本;我需要明确保持均值和方差等移动平均值。为了做到这一点,我正在对Tensorflow中的赋值和控制依赖机制进行一些实验,我遇到了一个神秘的问题。我有以下玩具代码;其中我试图测试tf.control_dependencies
是否按预期工作:
dataset = MnistDataSet(validation_sample_count=10000,
load_validation_from="validation_indices")
samples, labels, indices_list, one_hot_labels =
dataset.get_next_batch(batch_size=GlobalConstants.BATCH_SIZE)
samples = np.expand_dims(samples, axis=3)
flat_data = tf.contrib.layers.flatten(GlobalConstants.TRAIN_DATA_TENSOR)
mean = tf.Variable(name="mean", initial_value=tf.constant(100.0, shape=[784], dtype=tf.float32),
trainable=False, dtype=tf.float32)
a = tf.Variable(name="a", initial_value=5.0, trainable=False)
b = tf.Variable(name="b", initial_value=4.0, trainable=False)
c = tf.Variable(name="c", initial_value=0.0, trainable=False)
batch_mean, batch_var = tf.nn.moments(flat_data, [0])
b_op = tf.assign(b, a)
mean_op = tf.assign(mean, batch_mean)
with tf.control_dependencies([b_op, mean_op]):
c = a + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
results = sess.run([c, mean], feed_dict={GlobalConstants.TRAIN_DATA_TENSOR: samples})
我只是加载一个数据批处理,每个条目都有784个维度,计算它的时刻并尝试将batch_mean
存储到变量mean
中。我也很容易将变量a
的值存储到b
中。
在最后一行中,当我为c
和mean
的值运行图表时,我将c
视为10,这是预期值。但是mean
仍然是100的向量,并且不包含批量均值。这就像mean_op = tf.assign(mean, batch_mean)
尚未执行。
这可能是什么原因?据我所知,tf.control_dependencies
调用中的所有操作必须在以下上下文中的任何操作之前执行;我在这里明确地调用c
,这是在上下文中。我错过了什么吗?
答案 0 :(得分:3)
这是tf.Session.run()
的{{3}}。 c
和mean
操作是独立的,因此可以在mean
之前评估c
(这将更新mean
)。
这是此效果的缩写版本:
a = tf.Variable(name="a", initial_value=1.0, trainable=False)
b = tf.Variable(name="b", initial_value=0.0, trainable=False)
dependent_op = tf.assign(b, a * 3)
with tf.control_dependencies([dependent_op]):
c = a + 1
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([c, b]))
print(sess.run([b]))
b
的第二次评估保证会返回[3.0]
。但第一个run
可能会返回[2.0 3.0]
或[2.0 0.0]
。