需要长时间阅读:如何使用tf.contrib.layers.batch_norm
批处理规范化,而无需明确告知会话更新moving_statistics(moving_mean
和moving_variance
)不是吗?
几个月前,我提供了How could I use Batch Normalization in TensorFlow?的答案,并注意到我想解决的一些奇怪的细节。首先,我提供的实现似乎与is_training
变量重复。回想一下我建议的代码:
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
def batch_norm_layer(x,train_phase,scope_bn):
bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=True,
reuse=None, # is this right?
trainable=True,
scope=scope_bn)
bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=False,
reuse=True, # is this right?
trainable=True,
scope=scope_bn)
z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
return z
在其中我有一个train_phase
变量,它只包含一个tf布尔tf.placeholder(tf.bool, name='phase_train')
。如您所见,它用于确定批处理范数层是否应处于推理模式。但是,变量看起来有点多余,因为我似乎有两个变量指定两次相同的东西。即train_phase
中的一次和is_training
中的另一次。这真的有必要吗?
我考虑了一下,似乎我可能能够使用(伪)代码删除硬编码(is_training=True/False
):
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
def batch_norm_layer(x,train_phase,scope_bn):
bn = batch_norm(x, decay=0.999, center=True, scale=True,
updates_collections=None,
is_training=get_bool(train_phase),
reuse=None, # is this right?
trainable=True,
scope=scope_bn)
z = tf.cond(train_phase, lambda: bn, lambda: bn)
return z
似乎使train_phase
变量完全冗余/愚蠢。这实际上突出了我最重要的一点,train_phase
变量和tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
甚至是必要的吗?这实际上引起了我对代码的最大抱怨(虽然我认为这些代码甚至可能无法运行,因为在定义图表时,占位符train_phase可能甚至没有值,但你明白了。)
老实说,我发现甚至必须明确定义train_phase
非常危险,因为用户似乎没有必要明确地处理Batch Norm的推理/训练模式。虽然,"正常" Batch Norm的用户应始终使用列车数据更新moving_mean
,moving_variance
,并且Batch Norm的任何标准用户都不应更新moving_mean
,{{1随时都有测试统计信息。由于用户需要这样做:
moving_variance
它可能会给那些本来就不存在的用户带来非常糟糕的错误(至少在我看来)。此外,必须明确说明sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys, phase_train=True})
是什么似乎很奇怪,因为每当一个人训练时,一个人使用优化器,所以当调用该代码时它应该是真的,应该非常清楚。也许这是一个糟糕的想法,但感觉就像优化器或会话应该自动设置为真,而不是依靠用户来做正确。
据我所知,有时用户可以更灵活地获得更多创意,但我真的很欣赏这一点(即使是研究人员)也是一个很好的功能。也许我只是错误地使用库或者是偏执狂,但是在使用批量规范时,用户是否真的被迫如此明确?有什么方法可以解决这个问题吗?
作为一个侧面点,让phase_train
成为模型的一部分也会使代码变得比它感觉更加难看和混乱,因为在我看来,不可避免地要有一行代码会话用于检查批量标准标志是否打开。我试图避免编写的代码是逻辑:
phase_train
它只是觉得完全没必要。感觉在训练期间模型应该知道它是否应该更新变量。
作为会话中if条件的最后一个问题的快速(非常难看)解决方案,可以始终将if batch_norm:
# during training
sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys, phase_train=True})
else:
# with no batch norm
sess.run(fetches=train_step, feed_dict={x: batch_xs, y_: batch_ys})
定义为模型的一部分(或至少作为图的一部分)并相应地设置它在适当时等于true和/或false但是当一个人实际上没有使用批量规范层时,即使我们设置了一个也不会在模型中使用 phase_train
占位符它在session.run中有一个值。即,会话将其设置为真或假,但是当没有使用BN时,由于实际上没有使用BN,因此它并不重要。显然,这使得代码真的很混乱(因为一个人定义了一些甚至不需要的变量),但我似乎找不到隐藏phase_train
变量的方法。目前这就是我想要的,因为在具有以下内容的行之间拆分(或复制)我的代码似乎很难看。
phase_train
以及那些没有全部的人:
sess.run(fetches=..., feed_dict={...,phase_train=False})
理想情况下,无论是否使用愚蠢的sess.run(fetches=..., feed_dict={...})
变量,我都需要第二种解决方案,并且批处理规范可以正常工作。
答案 0 :(得分:0)
我对你的问题没有完整的答案,但我有一些意见:
is_training
的情况下构建。batch_norm
图层,以便您可以使用arg_scope
为模型中的所有图层设置is_training=True
。例如,看看如何在这里定义Inceptionv3模型:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py#L571。这至少使得在构建模型的Python代码中设置is_training
一次并让它适用于所有地方会更加方便。tf.Session
对神经网络,培训或推理并不了解,因此它不适合这种逻辑。Optimizer
应该重写图形,以便为那些支持它的运算符启用is_training
。我对此没有强烈的意见;您可以尝试提交Tensorflow Github问题,使该功能请求看到其他人对此的看法。它可能看起来有点太多了......#34;魔法"。希望有所帮助!