Keras的BatchNormalization层中的Moving_mean和Moving_variance

时间:2017-08-01 11:45:18

标签: tensorflow keras batch-normalization

我想将一组预先训练过的重量从Tensorflow输出到Keras。 问题是Tensorflow中的批量标准化层仅将Beta和Gamma嵌入到可训练的权重中,而在Keras中,我们也有Moving_mean和Moving_variance。 我很困惑从哪里获得这些重量。

1 个答案:

答案 0 :(得分:1)

试试tf.train.NewCheckpointReader。我最近将CNN模型从TF转换为Keras,并且用它导出移动均值/方差权重没有问题。

reader = tf.train.NewCheckpointReader(ckpt_file)
for key in reader.get_variable_to_shape_map():
    path = os.path.join(output_folder, get_filename(key))
    arr = reader.get_tensor(key)
    np.save(path, arr)
    print("tensor_name: ", key)

其中get_filename()只是将张量名称转换为正确文件名的函数。 (例如,用下划线替换斜线)

如果您对更多细节感兴趣,full code可能会有所帮助。