我已经基于此repo训练了Keras模型。
训练后,我将模型保存为如下检查点文件:
sess=tf.keras.backend.get_session()
saver = tf.train.Saver()
saver.save(sess, current_run_path + '/checkpoint_files/model_{}.ckpt'.format(date))
然后,我从检查点文件中还原图形,并使用标准的tf Frozen_graph脚本将其冻结。当我想恢复冻结的图时,出现以下错误:
Input 0 of node Conv_BN_1/cond/ReadVariableOp/Switch was passed float from Conv_BN_1/gamma:0 incompatible with expected resource
如何解决此问题?
编辑:我的问题与this问题有关。不幸的是,我无法使用替代方法。
编辑2: 我在github上打开了一个问题,并创建了一个要点来重现该错误。 https://github.com/keras-team/keras/issues/11032
答案 0 :(得分:8)
只需解决相同的问题。我联系了以下几个答案:1,2,3,并意识到这个问题源于 batchnorm层的工作状态:培训或学习。因此,为了解决该问题,您只需要在加载模型之前放置一行即可:
keras.backend.set_learning_phase(0)
完整示例,以导出模型
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3
def freeze_graph(graph, session, output):
with graph.as_default():
graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)
tf.keras.backend.set_learning_phase(0) # this line most important
base_model = InceptionV3()
session = tf.keras.backend.get_session()
INPUT_NODE = base_model.inputs[0].op.name
OUTPUT_NODE = base_model.outputs[0].op.name
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
加载* .pb模型:
from PIL import Image
import numpy as np
import tensorflow as tf
# https://i.imgur.com/tvOB18o.jpg
im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
im = np.array(im) / 255.0
im = im[None, ...]
graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_model.pb", "rb") as f:
graph_def.ParseFromString(f.read())
graph = tf.Graph()
with graph.as_default():
net_inp, net_out = tf.import_graph_def(
graph_def, return_elements=["input_1", "predictions/Softmax"]
)
with tf.Session(graph=graph) as sess:
out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
print(np.argmax(out))
答案 1 :(得分:1)
这是Tensorflow 1.1x的错误,另外一个答案是,这是因为内部批处理规范学习与推理状态有关。在TF 1.14.0中,尝试进行freeze a batch norm layer时实际上会收到一个神秘的错误。
使用set_learning_phase(0)
会将批处理规范层(可能还有其他对象,如辍学)置于推理模式,因此批处理规范层在训练期间将不起作用,从而导致准确性降低。
我的解决方法是:
K.set_learning_phase(0)
使用 ):def create_model():
inputs = Input(...)
...
return model
model = create_model()
model.save_weights("weights.h5")
K.clear_session()
K.set_learning_phase(0)
model = create_model()
model.load_weights("weights.h5")
答案 2 :(得分:0)
感谢您指出主要问题!我发现keras.backend.set_learning_phase(0)
有时不起作用,至少就我而言。
另一种方法可能是:for l in keras_model.layers: l.trainable = False