如何为tensorflow多GPU代码实现批量规范化层

时间:2017-12-26 09:44:24

标签: python tensorflow deep-learning distributed-computing batch-normalization

我创建了一个多GPU网络Cifar10_multigpu

在推理实施中,他们说:

  

我们使用tf.get_variable()而不是实例化所有变量     tf.Variable()以便跨多个GPU共享变量     训练跑。     如果我们只在单个GPU上运行此模型,我们可以简化这一点    功能     通过用tf.Variable()替换tf.get_variable()的所有实例。

所以我以所有我的conv2d图层为例,但是关于batchnorm图层呢?我如何自己实施?

在这种情况下我可以使用tensorflow.contrib.slim.batch_norm吗? 该示例不包含有关批处理规范层的任何建议。

1 个答案:

答案 0 :(得分:2)

只需使用tf.layers.batch_normalization即可。它还通过tf.get_variable()创建变量,因此也可以共享它们。

此外,它可与tf.layers.conv*函数无缝协作。

更新tf.nn.batch_normalization也很好。这是一个更低级别的功能,需要您自己管理meanvariance张贴。实际上,tf.layers.batch_normalizationtf.nn.*函数的包装器,其中还包含tf.nn.fused_batch_norm(更快的融合版本)。