为什么我们需要在tf.contrib.layers.batch_norm中的TensorFlow批量规范化中明确更新moving_mean和moving_variance?

时间:2017-01-30 00:14:36

标签: machine-learning tensorflow deep-learning

需要长时间阅读:如何使用tf.contrib.layers.batch_norm批处理规范化,而无需明确告知会话更新moving_statistics(moving_meanmoving_variance)不是吗?

几个月前,我提供了How could I use Batch Normalization in TensorFlow?的答案,并注意到我想解决的一些奇怪的细节。首先,我提供的实现似乎与is_training变量重复。回想一下我建议的代码:

from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm

def batch_norm_layer(x,train_phase,scope_bn):
    bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
    updates_collections=None,
    is_training=True,
    reuse=None, # is this right?
    trainable=True,
    scope=scope_bn)
    bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
    updates_collections=None,
    is_training=False,
    reuse=True, # is this right?
    trainable=True,
    scope=scope_bn)
    z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
    return z

在其中我有一个train_phase变量,它只包含一个tf布尔tf.placeholder(tf.bool, name='phase_train')。如您所见,它用于确定批处理范数层是否应处于推理模式。但是,变量看起来有点多余,因为我似乎有两个变量指定两次相同的东西。即train_phase中的一次和is_training中的另一次。这真的有必要吗?

我考虑了一下,似乎我可能能够使用(伪)代码删除硬编码(is_training=True/False):

from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm

def batch_norm_layer(x,train_phase,scope_bn):
    bn = batch_norm(x, decay=0.999, center=True, scale=True,
    updates_collections=None,
    is_training=get_bool(train_phase),
    reuse=None, # is this right?
    trainable=True,
    scope=scope_bn)
    z = tf.cond(train_phase, lambda: bn, lambda: bn)
    return z

似乎使train_phase变量完全冗余/愚蠢。这实际上突出了我最重要的一点,train_phase变量和tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)甚至是必要的吗?这实际上引起了我对代码的最大抱​​怨(虽然我认为这些代码甚至可能无法运行,因为在定义图表时,占位符train_phase可能甚至没有值,但你明白了。)

老实说,我发现甚至必须明确定义train_phase非常危险,因为用户似乎没有必要明确地处理Batch Norm的推理/训练模式。虽然,"正常" Batch Norm的用户应始终使用列车数据更新moving_meanmoving_variance,并且Batch Norm的任何标准用户都不应更新moving_mean,{{1随时都有测试统计信息。由于用户需要这样做:

moving_variance

它可能会给那些本来就不存在的用户带来非常糟糕的错误(至少在我看来)。此外,必须明确说明sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys, phase_train=True}) 是什么似乎很奇怪,因为每当一个人训练时,一个人使用优化器,所以当调用该代码时它应该是真的,应该非常清楚。也许这是一个糟糕的想法,但感觉就像优化器或会话应该自动设置为真,而不是依靠用户来做正确。

据我所知,有时用户可以更灵活地获得更多创意,但我真的很欣赏这一点(即使是研究人员)也是一个很好的功能。也许我只是错误地使用库或者是偏执狂,但是在使用批量规范时,用户是否真的被迫如此明确?有什么方法可以解决这个问题吗?

作为一个侧面点,让phase_train成为模型的一部分也会使代码变得比它感觉更加难看和混乱,因为在我看来,不可避免地要有一行代码会话用于检查批量标准标志是否打开。我试图避免编写的代码是逻辑:

phase_train

它只是觉得完全没必要。感觉在训练期间模型应该知道它是否应该更新变量。

作为会话中if条件的最后一个问题的快速(非常难看)解决方案,可以始终将if batch_norm: # during training sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys, phase_train=True}) else: # with no batch norm sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys}) 定义为模型的一部分(或至少作为图的一部分)并相应地设置它在适当时等于true和/或false但是当一个人实际上没有使用批量规范层时,即使我们设置了一个也不会在模型中使用 phase_train占位符它在session.run中有一个值。即,会话将其设置为真或假,但是当没有使用BN时,由于实际上没有使用BN,因此它并不重要。显然,这使得代码真的很混乱(因为一个人定义了一些甚至不需要的变量),但我似乎找不到隐藏phase_train变量的方法。目前这就是我想要的,因为在具有以下内容的行之间拆分(或复制)我的代码似乎很难看。

phase_train

以及那些没有全部的人:

sess.run(fetches=..., feed_dict={...,phase_train=False})

理想情况下,无论是否使用愚蠢的sess.run(fetches=..., feed_dict={...}) 变量,我都需要第二种解决方案,并且批处理规范可以正常工作。

1 个答案:

答案 0 :(得分:0)

我对你的问题没有完整的答案,但我有一些意见:

  • 标准做法似乎是为培训和推理构建略有不同的图表,每个图表都在启用或不启用is_training的情况下构建。
  • 设计batch_norm图层,以便您可以使用arg_scope为模型中的所有图层设置is_training=True。例如,看看如何在这里定义Inceptionv3模型: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py#L571。这至少使得在构建模型的Python代码中设置is_training一次并让它适用于所有地方会更加方便。
  • Tensorflow的底层基础设施并没有区分培训和推理时间 - 它只是运行运营商的图表。 tf.Session对神经网络,培训或推理并不了解,因此它不适合这种逻辑。
  • 可以想象Optimizer应该重写图形,以便为那些支持它的运算符启用is_training。我对此没有强烈的意见;您可以尝试提交Tensorflow Github问题,使该功能请求看到其他人对此的看法。它可能看起来有点太多了......#34;魔法"。

希望有所帮助!