使用布尔张量进行批量标准化更新

时间:2016-10-05 13:59:33

标签: python tensorflow

我知道在Tensorflow中使用批量规范化是一个已解决的问题,但在采用可满足我需求的代码时遇到了一些我不太了解的事情。

问题:似乎在使用条件分支时,将分支放在cond()函数中的顺序很重要。

我把以下最小的例子放在一起:

import tensorflow as tf

inpt = tf.placeholder(tf.float32, shape=[1])
do_update = tf.placeholder(tf.bool)

ema = tf.train.ExponentialMovingAverage(0.5)

def update():
    ema_assign = ema.apply([inpt])
    with tf.control_dependencies([ema_assign]):
        return tf.identity(ema.average(inpt))

def no_update():
    return ema.average(inpt)

run = tf.python.control_flow_ops.cond(do_update, update, no_update) # this works
# run = tf.python.control_flow_ops.cond(do_update, no_update, update) # this doesn't work
# run = tf.python.control_flow_ops.cond(tf.logical_not(do_update), no_update, update) # this doesn't work as well

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

for _ in range(10):
    run_v = sess.run([run], feed_dict={do_update: True, inpt: [1.0]})
    print run_v[0]

run_v = sess.run([run], feed_dict={do_update: False, inpt: [1000.0]})
print run_v[0]

第一行 run = ... 不是tensorflow的问题,对我来说是输出(如预期的那样):

[ 0.5]
[ 0.75]
[ 0.875]
[ 0.9375]
[ 0.96875]
[ 0.984375]
[ 0.9921875]
[ 0.99609375]
[ 0.99804688]
[ 0.99902344]
[ 0.99902344]

然而,对于第二行和第三行,我得到一个ValueError:

ValueError: fn1 and fn2 must return the same number of results.

似乎重要的是fn1分支是包含平均更新的分支,但我不知道为什么这是有道理的。有关这种行为的任何想法吗?

0 个答案:

没有答案