导出的Tensorflow模型不保留占位符形状

时间:2016-11-22 04:24:25

标签: tensorflow tensorflow-serving

我正在使用exporter中的tensorflow.contrib.session_bundle来保存我的模型:

x = tf.placeholder(tf.float32, (None,) + (100, 200) + (1,))
....
saver = tf_saver.Saver(sharded=True)
model_exporter = exporter.Exporter(saver)
model_exporter.init(
    sess.graph.as_graph_def(),
    named_graph_signatures={
        'inputs': exporter.generic_signature({'images': x}),
        'outputs': exporter.generic_signature({'classes': y})})

然后我将模型加载回(session_bundle中的tensorflow.contrib.session_bundle):

sess, meta_graph_def = session_bundle.load_session_bundle_from_path(input)

但是当我检查对应于输入x的占位符张量时,我看不到形状信息:

> sess.graph.get_tensor_by_name(input_name)
<tf.Tensor 'Placeholder:0' shape=<unknown> dtype=float32>

这是设计还是有一些导致形状丢失的错误?

1 个答案:

答案 0 :(得分:0)

以下是同事的回答:

exporter.generic_signature调用(构建named_graph_signatures时)填充此处定义的generic_signature地图:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/manifest.proto#L69

地图中的值是TensorBinding,它本身就是张量名称。见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/manifest.proto#L20

因此,预计不会保留形状,名称应足以识别张量。“