我有一个训练有素的张量流模型,以检查点,.data,.meta和.index文件的形式保存。该模型使用批量标准化。
我尝试使用freeze_graph将其转换为.pb文件,可以导入为from tensorflow.python.tools import freeze_graph
。对此的输入也是.pb
文件,但只有图形结构。
我使用以下代码恢复模型
sess = tf.Session()
saver = tf.train.import_meta_graph(r'.\path\to\model\VanillaCNN.0000.meta')
saver.restore(sess, tf.train.latest_checkpoint(r'.\path\to\model'))
graph = tf.get_default_graph()
然后使用
创建包含图形结构的.pb
文件
tf.train.write_graph(sess.graph_def, "", "model_proto.pb", False)
在此之后,我使用freeze_graph
生成.pb
文件,其中包含图表结构和权重。
freeze_graph
的输入是
input_graph_path = r'.\path\to\model\model_proto.pb'
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = r'.\path\to\model\VanillaCNN.0000'
output_node_names = "VanillaCNNoutput_10/layer_output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = r'.\path\to\model\frozen_model.pb'
clear_devices = False
initializer_nodes=""
执行
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,input_binary, input_checkpoint_path,output_node_names, restore_op_name,filename_tensor_name, output_graph_path,clear_devices,initializer_nodes)
当我尝试将其加载回来时,这会创建frozen_model.pb
def load_graph(frozen_graph_filename):
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, input_map=None, return_elements=None, name="", op_dict=None, producer_op_list=None)
return graph
它会抛出以下错误
ValueError: graph_def is invalid at node 'VanillaCNNconv_0/VanillaCNNconv_0/cond/Assign': Input tensor 'VanillaCNNconv_0/VanillaCNNconv_0/cond/Assign/Switch:1' Cannot convert a tensor of type float32 to an input of type float32_ref.
我该如何解决这个问题?