我在尝试对mnist位数数据集实施tf.nn.batch_normalization时遇到问题。
tf.VERSION:1.13.0-rc2 操作系统:MacBook Pro(13英寸,2017年)上的MacOS Mojave 10.14.4
出于测试目的,我实现了一个非常简单的网络
当我使用tf.nn.batch_normalization
Ybn1 = tf.nn.batch_normalization(Yl1, m1, v1, O1, S1, 1e-5)
这会导致分歧:
https://user-images.githubusercontent.com/44782534/56353537-df10f380-61d1-11e9-9060-749a04f55cb4.png
如果我自己做数学,那它将收敛:
Yhat1 = (Yl1 - m1) / tf.sqrt(v1 + 1e-5)
Ybn1 = S1 * Yhat1 + O1
https://user-images.githubusercontent.com/44782534/56353771-6eb6a200-61d2-11e9-9d4f-08a6096091ea.png
如果我在代码中实现了batch_normalization函数的内容,那么它也不起作用。
inv = math_ops.rsqrt(v1 + 1e-5)
inv *= S1
Ybn1 = Yl1 * math_ops.cast(inv, Yl1.dtype) + math_ops.cast(O1 - m1 * inv, Yl1.dtype)
但是如果我将最后两行结合起来,它将正常工作:
inv = math_ops.rsqrt(v1 + 1e-5)
Ybn1 = (Yl1 - m1) * inv * S1+ O1
我当然做错了,但是我不知道是什么,欢迎任何帮助:)
如果要重现此问题,请使用以下完整代码: https://github.com/neodelphis/tensorflow-without-a-phd-french/blob/master/mnist_test.ipynb
谢谢