合并张量流图时如何保留占位符名称?

时间:2016-12-19 10:25:44

标签: python tensorflow

我想在先前创建的张量流图中添加分支。我按照mrry对这个问题(Tensorflow: How to replace a node in a calculation graph?)的回答来做到这一点,并保存了新图表的定义。

当我导入新图并尝试获取原始图的占位符时,我收到以下错误:ValueError: Requested return_element 'pool_3/_reshape:0' not found in graph_def.,但是当我使用原始图时代码工作正常。

如何维护对原始占位符的引用

我合并这两个图的代码是

with tf.Session() as sess:

# Get the b64_graph and its output tensor
resized_b64_tensor, = (tf.import_graph_def(b64_graph_def, name='',
          return_elements=[B64_OUTPUT_TENSOR_NAME+":0"]))

with gfile.FastGFile(model_filename, 'rb') as f:
  inception_graph_def = tf.GraphDef()
  inception_graph_def.ParseFromString(f.read())

  # Concatenate b64_graph and inception_graph
  g_1 = tf.import_graph_def(inception_graph_def, name='graph_name',
               input_map={RESIZED_INPUT_TENSOR_NAME : resized_b64_tensor})

  # Save joined graph
  joined_graph = sess.graph
  with gfile.FastGFile(output_graph_filename, 'wb') as f:
    f.write( joined_graph.as_graph_def().SerializeToString() )

1 个答案:

答案 0 :(得分:3)

我通过阅读这篇文章Working with multiple graphs in TensorFlow间接找到了解决方案。

当图形具有指定的名称时,它将附加到张量的名称和它包含的操作。特别是,如果我将两个图表连接成一个新图表,后者的名称将附加到以前的名称。 因此,获得张量的正确代码将是

sess.graph.get_tensor_by_name('graph_name/' + 'PreviousGraphName/PreviousTensorName:0')