我有一个Keras模型,正在尝试导出到ProtoBuf
最后两层看起来像这样:
features (Dense) (None, 128) 49280 concatenate_1[0][0]
__________________________________________________________________________________________________
gaze_target (Dense) (None, 2) 258 features[0][0]
我尝试像这样导出它:
sess = K.get_session()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), 'gaze_target')
graph_io.write_graph(constant_graph, 'export', 'output.pb', as_text=False)
此错误:
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
191
192 if isinstance(dest_nodes, six.string_types):
--> 193 raise TypeError("dest_nodes must be a list.")
194
195 name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
TypeError: dest_nodes must be a list.
如何将此模型导出到ProtoBuf? (最终在SageMaker上使用)
答案 0 :(得分:0)
多亏了一位同事的工作,我们才得以解决。 graph_util.convert_variables_to_constants
方法的参数不是图层名称,而是操作名称(op.name
)。
正确的代码是:
sess = K.get_session()
outputs = [out.op.name for out in model.outputs] # Note this new line
constant_graph = graph_util.convert_variables_to_constants(sess,
sess.graph.as_graph_def(),
outputs)
graph_io.write_graph(constant_graph, 'export', 'output.pb', as_text=False)