如何在TensorFlow图的开头“追加”Op?

时间:2017-01-28 11:13:46

标签: tensorflow

我有一个GraphDef原型文件,我使用tf.import_graph_def导入。可以在图表的末尾添加Ops,如下所示:

final_tensor = tf.import_graph_def(graph_def, name='', return_elements=['final_tensor'])
new_tensor = some_op(final_tensor)

但是我想在图表的开头添加Ops,所以graph_def中的第一个Op基本上需要将我的Op的输出作为输入,我该怎么做?

2 个答案:

答案 0 :(得分:7)

终于找到了这样做的方法。我确信评论中提到的Yarolsav函数在内部做了类似的事情。

new_input = graph_def.node.add()
new_input.op = 'new_op_name'  # eg: 'Const', 'Placeholder', 'Add' etc
new_input.name = 'some_new_name'
# set any attributes you want for new_input here
old_input.input[0] = 'some_new_name'  #  must match with the name above

有关如何设置属性的详细信息,请参阅this文件。

答案 1 :(得分:0)

@Priyatham在链接中提供的脚本是一个很好的示例,说明如何在tf graph_def中添加节点。 nameopinputattr是4个必需元素。可以分配nameop,而input应该使用extendattr应该使用CopyFrom进行分配,例如:

new_node = graph_def.node.add()
new_node.op = "Cast"
new_node.name = "To_Float"
new_node.input.extend(["To_Float"])
new_node.attr["DstT"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
new_node.attr["SrcT"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
new_node.attr["Truncate"].CopyFrom(attr_value_pb2.AttrValue(b=True))