我正在使用exporter
中的tensorflow.contrib.session_bundle
来保存我的模型:
x = tf.placeholder(tf.float32, (None,) + (100, 200) + (1,))
....
saver = tf_saver.Saver(sharded=True)
model_exporter = exporter.Exporter(saver)
model_exporter.init(
sess.graph.as_graph_def(),
named_graph_signatures={
'inputs': exporter.generic_signature({'images': x}),
'outputs': exporter.generic_signature({'classes': y})})
然后我将模型加载回(session_bundle
中的tensorflow.contrib.session_bundle
):
sess, meta_graph_def = session_bundle.load_session_bundle_from_path(input)
但是当我检查对应于输入x的占位符张量时,我看不到形状信息:
> sess.graph.get_tensor_by_name(input_name)
<tf.Tensor 'Placeholder:0' shape=<unknown> dtype=float32>
这是设计还是有一些导致形状丢失的错误?
答案 0 :(得分:0)
以下是同事的回答:
“exporter.generic_signature
调用(构建named_graph_signatures
时)填充此处定义的generic_signature
地图:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/manifest.proto#L69
地图中的值是TensorBinding
,它本身就是张量名称。见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/manifest.proto#L20
因此,预计不会保留形状,名称应足以识别张量。“