tf.layers.batch_normalization中“可训练”和“训练”标志的重要性

时间:2018-05-07 07:42:56

标签: tensorflow batch-normalization

tf.layers.batch_normalization中“可训练”和“训练”标志的意义是什么?在训练和预测期间,这两者有何不同?

3 个答案:

答案 0 :(得分:4)

批次规范分为两个阶段:

1. Training:
   -  Normalize layer activations using `moving_avg`, `moving_var`, `beta` and `gamma` 
     (`training`* should be `True`.)
   -  update the `moving_avg` and `moving_var` statistics. 
     (`trainable` should be `True`)
2. Inference:
   -  Normalize layer activations using `beta` and `gamma`.
      (`training` should be `False`)

用于说明少数情况的示例代码:

#random image
img = np.random.randint(0,10,(2,2,4)).astype(np.float32)

# batch norm params initialized
beta = np.ones((4)).astype(np.float32)*1 # all ones 
gamma = np.ones((4)).astype(np.float32)*2 # all twos
moving_mean = np.zeros((4)).astype(np.float32) # all zeros
moving_var = np.ones((4)).astype(np.float32) # all ones

#Placeholders for input image
_input = tf.placeholder(tf.float32, shape=(1,2,2,4), name='input')

#batch Norm
out = tf.layers.batch_normalization(
       _input,
       beta_initializer=tf.constant_initializer(beta),
       gamma_initializer=tf.constant_initializer(gamma),
       moving_mean_initializer=tf.constant_initializer(moving_mean),
       moving_variance_initializer=tf.constant_initializer(moving_var),
       training=False, trainable=False)


update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
init_op = tf.global_variables_initializer()

 ## 2. Run the graph in a session 

 with tf.Session() as sess:

    # init the variables   
    sess.run(init_op)

    for i in range(2):
        ops, o = sess.run([update_ops, out], feed_dict={_input: np.expand_dims(img, 0)})
        print('beta', sess.run('batch_normalization/beta:0'))
        print('gamma', sess.run('batch_normalization/gamma:0'))
        print('moving_avg',sess.run('batch_normalization/moving_mean:0'))
        print('moving_variance',sess.run('batch_normalization/moving_variance:0'))
        print('out', np.round(o))
        print('')

training=Falsetrainable=False

  img = [[[4., 5., 9., 0.]...
  out = [[ 9. 11. 19.  1.]... 
  The activation is scaled/shifted using gamma and beta.

training=Truetrainable=False

  out = [[ 2.  2.  3. -1.] ...
  The activation is normalized using `moving_avg`, `moving_var`, `gamma` and `beta`. 
  The averages are not updated.

traning=Truetrainable=True

  The out is same as above, but the `moving_avg` and `moving_var` gets updated to new values.

  moving_avg [0.03249997 0.03499997 0.06499994 0.02749997]
  moving_variance [1.0791667 1.1266665 1.0999999 1.0925]

答案 1 :(得分:1)

training控制是否使用训练模式batchnorm(使用来自此minibatch的统计数据)或推理模式batchnorm(使用训练数据中的平均统计数据)。 trainable控制在batchnorm过程中创建的变量本身是否可以训练。

答案 2 :(得分:0)

这很复杂。 在TF 2.0中,行为已更改,请参见:

https://github.com/tensorflow/tensorflow/blob/095272a4dd259e8acd3bc18e9eb5225e7a4d7476/tensorflow/python/keras/layers/normalization_v2.py#L26

  

关于在layer.trainable = False层上设置BatchNormalization

     

设置layer.trainable = False的含义是冻结   层,即其内部状态在训练期间不会改变:
  其可训练权重在fit()期间不会更新,或者   train_on_batch(),其状态更新将不会运行。通常,   这并不一定意味着该层以推理方式运行
  模式(通常由training参数控制,该参数可以   在调用图层时被传递)。 “冻结状态”和“推理模式”
  是两个不同的概念。

     

但是,对于BatchNormalization层,设置
  图层上的trainable = False表示该图层将
  随后以推断模式运行
(表示它将使用   移动平均值和移动方差以规范当前批次,
  而不是使用当前批次的均值和方差)。这个   TensorFlow 2.0中已引入行为,以启用   layer.trainable = False以产生最普遍期望的   卷积微调用例中的行为。请注意:

     
      
  • 此行为仅在TensorFlow 2.0以后出现。在1. *中,设置layer.trainable = False将冻结图层,但不会冻结   切换到推理模式。
  •   
  • 在包含其他图层的模型上设置trainable会递归设置所有内部图层的trainable值。
  •   
  • 如果在模型上调用trainable后更改了compile()属性的值,则新值对此无效   直到再次调用compile()为止。
  •