将Keras模型以(无,2)输出形状导出到protobuf

时间:2020-01-24 04:22:43

标签: python tensorflow keras protocol-buffers amazon-sagemaker

我有一个Keras模型,正在尝试导出到ProtoBuf

最后两层看起来像这样:

features (Dense)                (None, 128)          49280       concatenate_1[0][0]              
__________________________________________________________________________________________________
gaze_target (Dense)             (None, 2)            258         features[0][0]      

我尝试像这样导出它:

sess = K.get_session()

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), 'gaze_target')
graph_io.write_graph(constant_graph, 'export', 'output.pb', as_text=False)

此错误:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
    191 
    192   if isinstance(dest_nodes, six.string_types):
--> 193     raise TypeError("dest_nodes must be a list.")
    194 
    195   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(

TypeError: dest_nodes must be a list.

如何将此模型导出到ProtoBuf? (最终在SageMaker上使用)

1 个答案:

答案 0 :(得分:0)

多亏了一位同事的工作,我们才得以解决。 graph_util.convert_variables_to_constants方法的参数不是图层名称,而是操作名称(op.name)。

正确的代码是:

sess = K.get_session()

outputs = [out.op.name for out in model.outputs] # Note this new line

constant_graph = graph_util.convert_variables_to_constants(sess, 
                                             sess.graph.as_graph_def(), 
                                             outputs)

graph_io.write_graph(constant_graph, 'export', 'output.pb', as_text=False)