If I implement batch normalization myself (written here with NumPy), it should look like the following program.
import numpy as np

def batchnorm_forward(X, gamma, beta):
    # Per-feature statistics over the batch dimension
    mu = np.mean(X, axis=0)
    var = np.var(X, axis=0)
    # Normalize, then scale and shift
    X_norm = (X - mu) / np.sqrt(var + 1e-8)
    out = gamma * X_norm + beta
    cache = (X, X_norm, mu, var, gamma, beta)
    return out, cache, mu, var
To be able to run inference later, the statistics (mean and variance) have to be saved during training, like this.
# BatchNorm training forward propagation
h2, bn2_cache, mu, var = batchnorm_forward(h2, gamma2, beta2)
bn_params['bn2_mean'] = .9 * bn_params['bn2_mean'] + .1 * mu
bn_params['bn2_var'] = .9 * bn_params['bn2_var'] + .1 * var
At inference time only, the following code is used instead.
# BatchNorm inference forward propagation
h2 = (h2 - bn_params['bn2_mean']) / np.sqrt(bn_params['bn2_var'] + 1e-8)
h2 = gamma2 * h2 + beta2
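To make the two phases above concrete, here is a minimal, dependency-free sketch for a single feature. The mini-batches, the initial values (mean 0, variance 1), and the helper names `update_running` and `normalize` are assumptions made for illustration; only the 0.9 / 0.1 update rule and the `1e-8` epsilon come from the code above.

```python
import math
import statistics

EPS = 1e-8       # same epsilon as in the snippets above
MOMENTUM = 0.9   # same decay as the .9 / .1 update above


def update_running(running, batch_value, momentum=MOMENTUM):
    """Exponential moving average: keep `momentum` of the old value."""
    return momentum * running + (1.0 - momentum) * batch_value


def normalize(x, mean, var, gamma=1.0, beta=0.0):
    """Normalize one activation with the given statistics, then scale and shift."""
    return gamma * (x - mean) / math.sqrt(var + EPS) + beta


# Hypothetical mini-batches of a single feature, each with batch mean 5.
batches = [[4.0, 5.0, 6.0], [3.0, 5.0, 7.0], [4.5, 5.0, 5.5]]

running_mean, running_var = 0.0, 1.0  # assumed initial values
for _ in range(200):  # repeat the batches so the averages settle
    for batch in batches:
        mu = statistics.mean(batch)
        var = statistics.pvariance(batch)  # population variance, like np.var
        running_mean = update_running(running_mean, mu)
        running_var = update_running(running_var, var)

# Inference: a single sample is normalized with the stored statistics,
# not with statistics computed from the sample itself.
print(running_mean)
print(normalize(5.0, running_mean, running_var))
```

After enough updates the running mean approaches the batch mean, so a sample equal to that mean is mapped close to zero at inference time.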
In TensorFlow, how can I get the actual values of the variables that correspond to bn_params['bn2_mean'] and bn_params['bn2_var']?
with tf.name_scope('fc1'):
    w1 = weight_variable([7 * 7 * 16, 32])
    h1 = tf.matmul(pool_flat2, w1)
    fc1_bn = tf.contrib.layers.batch_norm(inputs=h1, is_training=phase_train)
    fc1_bn_relu = tf.nn.relu(fc1_bn)
...
...
...
...
# ????? how do I get the variables ?????
# Roughly what I have in mind:
mean, var = fc1_bn.eval()
Please help me :<
Answer 0 (score: 0)
I use the object-oriented version of the layer. Then it is as simple as accessing the object's attributes:
>>> import tensorflow as tf
>>> bn = tf.layers.BatchNormalization()
>>> bn(tf.ones([1, 3]))
<tf.Tensor 'batch_normalization/batchnorm/add_1:0' shape=(1, 3) dtype=float32>
>>> bn.variables
[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32_ref>]
>>> bn.moving_mean
<tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32_ref>
>>> bn.moving_variance
<tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32_ref>
>>>
Answer 1 (score: 0)
The object-oriented version, tf.layers.BatchNormalization(), is dangerous because its constructor has no training=True/False argument. You have to apply it like this:
bn_instance = tf.layers.BatchNormalization(trainable=True)
batch_layer = bn_instance.apply(input,training=True)
Otherwise, it will normalize the data during training with the initial moving_mean (= 0) and moving_variance (= 1), which should only be used for inference.
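The effect described above can be checked with plain arithmetic. The sketch below uses only the standard library, and the batch values, the epsilon of 1e-3, and the `normalize` helper are assumptions chosen for illustration. It compares normalizing with the initial moving statistics (mean 0, variance 1) against normalizing with the statistics of the batch itself.

```python
import math

EPS = 1e-3  # assumed epsilon for this illustration


def normalize(values, mean, var, eps=EPS):
    """Apply (v - mean) / sqrt(var + eps) to every value."""
    return [(v - mean) / math.sqrt(var + eps) for v in values]


batch = [100.0, 110.0, 120.0]  # hypothetical activations, far from zero mean

# Inference-style statistics before any update: mean 0, variance 1.
# The data passes through almost unchanged, so nothing is normalized.
with_initial_stats = normalize(batch, mean=0.0, var=1.0)

# Training-style statistics: computed from the batch itself.
mu = sum(batch) / len(batch)
var = sum((v - mu) ** 2 for v in batch) / len(batch)
with_batch_stats = normalize(batch, mu, var)

print(with_initial_stats)
print(with_batch_stats)
```

With the initial statistics the activations stay around 100, while with the batch statistics they are centered on zero, which is why the training flag matters when the layer is applied.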