Tensorflow如何将批量标准化应用于推理

时间:2017-02-13 13:13:31

标签: tensorflow normalization

您好。我对tf很新,问题是如何将批量标准化应用于推理。我在训练和测试期间应用了具有张量流的批量标准化。代码如下,您可以看到,我使用不同的均值和var值来处理训练和测试。现在培训已经完成,我想将模型应用到实践中。比如使用demo加载ckpt文件并测试一个例子。在这种情况下,我如何规范这一案例?有没有什么方法可以在每个训练时期之后保存BN的均值和变量并在以后恢复?非常感谢你!

    fc_mean, fc_var = tf.nn.moments(
            input,
            axes=[0], 
        )
        scale = tf.Variable(tf.ones([out_size]))
        shift = tf.Variable(tf.zeros([out_size]))
        epsilon = 0.001
        ema = tf.train.ExponentialMovingAverage(decay=0.5)
        def mean_var_with_update():
            ema_apply_op = ema.apply([fc_mean, fc_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(fc_mean), tf.identity(fc_var)
        mean, var = tf.cond(train_phase,
                             mean_var_with_update,
                             lambda: (ema.average(fc_mean),
                                      ema.average(fc_var)
                                      )
                             )
        input_BN = tf.nn.batch_normalization(input, mean, var, shift, scale, epsilon)

1 个答案:

答案 0 :(得分:0)

我不确定,但你们认为这会是对的吗?我保存了移位(偏移)和比例,这在训练期间产生。当我需要将模型应用于练习时,我只需使用shift和scale来替换均值和var。并将测试用例的shift和var设置为零。

    if not test:
        input_BN = tf.nn.batch_normalization(input, mean, var, shift, scale, epsilon)
    else:
        input_BN = tf.nn.batch_normalization(input, shift, scale, 0, 0, epsilon)