我想使用冻结图来修改/更改操作。 例如,如下所示,将操作从“调整大小方法为ResizeNearestNeighbor的upsample2D”更改为“调整大小方法为ResizeBilinear的upsample2D”
with tf.device('/gpu:0'):
with tf.gfile.GFile(filename, 'rb') as file:
serialized_graph = file.read()
graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(graph_def, name='')
graph_replace = tf.contrib.graph_editor.graph_replace
nodes = graph_def.node
for node in nodes:
if "ResizeNearestNeighbor" in node.name :
print ("===========> ", node.name)
node.op ="ResizeBilinear"
# also need to change node name
nodes = graph_def.node
for node in nodes:
print (node.name)
tf.train.write_graph(graph_def, "./", name='modified.pb')
实际上,以上代码无法正常工作;我认为这是由于nodedef中的哈希类型所致;另外,解码错误google.protobuf.message.DecodeError:导入修改后的图形时解析消息时出错
我认为以下方法可能有效,但是对此有帮助吗?
graph_replace = tf.contrib.graph_editor.graph_replace
graph_replace(node, {node.xx: new_node.xx })
或者
tf.import_graph_def(graph_def, input_map={node: a new node})
谢谢
答案 0 :(得分:0)
您的主代码块有些奇怪:
graph_replace
,但从不使用它。node.name
上进行了匹配。名称几乎可以是任何东西。您应该在node.op
上进行匹配,node.op ="ResizeBilinear"
是操作的“类型”。这些名称是固定的。char a; boom = (uint64) a
中的节点类型。这听起来不对。它类似于C中的graph_editor
。您不能只更改某些内容的“类型”。通常,手动修改GraphDef是一个坏主意。它不是公共界面的一部分,可以随时更改。
使用Transformer
可能是最好的方法。您可以使用图形transform_op_handler
并覆盖ValidateOnParse
。有关使用Transformer的基本示例,请参见此test。您可以将处理程序基于默认的one,该默认值仅按原样复制节点。如果有帮助,请在此处使用此变压器的place。
答案 1 :(得分:0)
我刚刚遇到了类似的问题并找到了解决方案。我不知道是否可以仅重命名操作,因此我想您需要完全交换节点。
要解决此问题,您需要定义一个新的操作,如下所示:
output_tensor= tf.image.resize_images(input_tensor, [300, 300], method=tf.image.ResizeMethod.BILINEAR)
然后使用
tf.import_graph_def(graph_model_def, name='', input_map={"existing_input_tensor": input_tensor}, return_elements=['data/inputs:0'])
Here是更详细的说明。