无法使用BatchNorm层导入冻结图

时间:2018-08-15 11:53:52

标签: tensorflow keras

我已经基于此repo训练了Keras模型。

训练后,我将模型保存为如下检查点文件:

 sess=tf.keras.backend.get_session() 
 saver = tf.train.Saver()
 saver.save(sess, current_run_path + '/checkpoint_files/model_{}.ckpt'.format(date))

然后,我从检查点文件中还原图形,并使用标准的tf Frozen_graph脚本将其冻结。当我想恢复冻结的图时,出现以下错误:

Input 0 of node Conv_BN_1/cond/ReadVariableOp/Switch was passed float from Conv_BN_1/gamma:0 incompatible with expected resource

如何解决此问题?

编辑:我的问题与this问题有关。不幸的是,我无法使用替代方法。

编辑2: 我在github上打开了一个问题,并创建了一个要点来重现该错误。 https://github.com/keras-team/keras/issues/11032

3 个答案:

答案 0 :(得分:8)

只需解决相同的问题。我联系了以下几个答案:123,并意识到这个问题源于 batchnorm层的工作状态:培训或学习。因此,为了解决该问题,您只需要在加载模型之前放置一行即可:

keras.backend.set_learning_phase(0)

完整示例,以导出模型

import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3


def freeze_graph(graph, session, output):
    with graph.as_default():
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)

tf.keras.backend.set_learning_phase(0) # this line most important

base_model = InceptionV3()

session = tf.keras.backend.get_session()

INPUT_NODE = base_model.inputs[0].op.name
OUTPUT_NODE = base_model.outputs[0].op.name
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])

加载* .pb模型:

from PIL import Image
import numpy as np
import tensorflow as tf

# https://i.imgur.com/tvOB18o.jpg
im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
im = np.array(im) / 255.0
im = im[None, ...]

graph_def = tf.GraphDef()

with tf.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()

with graph.as_default():
    net_inp, net_out = tf.import_graph_def(
        graph_def, return_elements=["input_1", "predictions/Softmax"]
    )
    with tf.Session(graph=graph) as sess:
        out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
        print(np.argmax(out))

答案 1 :(得分:1)

这是Tensorflow 1.1x的错误,另外一个答案是,这是因为内部批处理规范学习与推理状态有关。在TF 1.14.0中,尝试进行freeze a batch norm layer时实际上会收到一个神秘的错误。

使用set_learning_phase(0)会将批处理规范层(可能还有其他对象,如辍学)置于推理模式,因此批处理规范层在训练期间将不起作用,从而导致准确性降低。

我的解决方法是:

  1. 使用函数创建模型(不要使用K.set_learning_phase(0)使用 ):
def create_model():
    inputs = Input(...)
    ...
    return model

model = create_model()
  1. 火车模型
  2. 节省体重: model.save_weights("weights.h5")
  3. 清除会话(重要的是层名称相同)并将学习阶段设置为0:
K.clear_session()
K.set_learning_phase(0)
  1. 重新创建模型和负载权重:
model = create_model()
model.load_weights("weights.h5")
  1. 像以前一样冻结

答案 2 :(得分:0)

感谢您指出主要问题!我发现keras.backend.set_learning_phase(0)有时不起作用,至少就我而言。

另一种方法可能是:for l in keras_model.layers: l.trainable = False