我知道在Tensorflow中使用批量规范化是一个已解决的问题,但在采用可满足我需求的代码时遇到了一些我不太了解的事情。
问题:似乎在使用条件分支时,将分支放在cond()函数中的顺序很重要。
我把以下最小的例子放在一起:
import tensorflow as tf
inpt = tf.placeholder(tf.float32, shape=[1])
do_update = tf.placeholder(tf.bool)
ema = tf.train.ExponentialMovingAverage(0.5)
def update():
ema_assign = ema.apply([inpt])
with tf.control_dependencies([ema_assign]):
return tf.identity(ema.average(inpt))
def no_update():
return ema.average(inpt)
run = tf.python.control_flow_ops.cond(do_update, update, no_update) # this works
# run = tf.python.control_flow_ops.cond(do_update, no_update, update) # this doesn't work
# run = tf.python.control_flow_ops.cond(tf.logical_not(do_update), no_update, update) # this doesn't work as well
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
for _ in range(10):
run_v = sess.run([run], feed_dict={do_update: True, inpt: [1.0]})
print run_v[0]
run_v = sess.run([run], feed_dict={do_update: False, inpt: [1000.0]})
print run_v[0]
第一行 run = ... 不是tensorflow的问题,对我来说是输出(如预期的那样):
[ 0.5]
[ 0.75]
[ 0.875]
[ 0.9375]
[ 0.96875]
[ 0.984375]
[ 0.9921875]
[ 0.99609375]
[ 0.99804688]
[ 0.99902344]
[ 0.99902344]
然而,对于第二行和第三行,我得到一个ValueError:
ValueError: fn1 and fn2 must return the same number of results.
似乎重要的是fn1分支是包含平均更新的分支,但我不知道为什么这是有道理的。有关这种行为的任何想法吗?