批量标准化张量流

时间:2018-12-03 00:35:26

标签: python tensorflow yolo batch-normalization

我正在尝试在tensorflow中实现YOLOv3模型。 到目前为止,我已经成功地将模型从Darknet移植到tensorflow并做出了一些合理的推断。

我正在使用此script进行模型转换。

我正在使用以下代码片段来实现批处理规范化。

with tf.name_scope('batch_norm'):


    with tf.variable_scope('parameters', reuse=tf.AUTO_REUSE):

        moving_mean = tf.Variable(bn_weight_list[2],
            name='moving_mean')
        moving_variance = tf.Variable(bn_weight_list[3],
            name='moving_variance')
        beta = tf.Variable(bn_weight_list[1], name='beta')
        gamma = tf.Variable(bn_weight_list[0], name='gamma')


    def train():
        mean, variance = tf.nn.moments(x=conv_layer,
            axes=[0, 1, 2])

        # Calculating the exponential moving average for inference time
        mv_mean = tf.assign(moving_mean,
            moving_mean*momentum + mean*(1-momentum))
        mv_variance = tf.assign(moving_variance,
            moving_variance*momentum + mean*(1-momentum))


        with tf.control_dependencies([mv_mean, mv_variance]):
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mv_mean)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, mv_variance)
            return tf.identity(tf.nn.batch_normalization(x=conv_layer, 
                mean=mean, variance=variance, 
                offset=beta, scale=gamma, 
                variance_epsilon=1e-3), name='batch_norm_train')
            # return tf.identity(mean), tf.identity(variance)



    def valtest():

        # print_stmt = tf.Print(moving_mean, [moving_mean])

        return tf.identity(tf.nn.batch_normalization(x=conv_layer, 
            mean=moving_mean, variance=moving_variance, 
            offset=beta, scale=gamma, 
            variance_epsilon=1e-3), name='batch_norm_valtest')

        # return tf.identity(moving_mean), tf.identity(moving_variance)



    conv_layer = tf.case([(is_training, train)], valtest)

这里bn_weight_list是YOLO的原始作者在yolov3中提供的预训练权重的python列表.weights和is_training是张量流占位符。

现在的问题是,当我使用转换后的模型进行推理时,我得到了合理的结果,但是当我训练模型时,训练损失降低到了约15。 但是,当我测试模型时,通过设置标志is_training = False不会得到任何结果,但是将其设置为true会给我一些麻烦。 我只是想知道我的批处理规范化实施是否正确。

0 个答案:

没有答案