我使用tensorflow定义了一个深层CNN,包括一个批处理规范化操作,即我的代码可能像这样:
def network(input):
...
input = tf.layers.batch_normalization(input, ...)
...
假定网络已接受培训,并且检查点文件已保存。现在,我想使用此模型进行推断。通常,除了将参数network(input)
传递到training=False
之外,我可以再次调用函数tf.layers.batch_normalization()
,然后从检查点文件恢复权重。
但是,由于可以更改功能tf.import_meta_graph
中的代码,我更愿意使用network(input)
来重建我的网络。
但是现在我该如何在推理模式下设置批处理规范化操作?由于我无权使用函数tf.layers.batch_normalization()
,因此解决该问题有些困难。