我正在尝试从图中删除一些节点,并将结果保存为 .pb 文件。
我可以只把需要的节点添加到新的 mod_graph_def 图中,
但问题是:该图中其他节点的输入仍然引用了已删除的节点,而我无法修改节点的输入:
def delete_ops_from_graph():
    """Copy a serialized GraphDef, leaving out nodes whose name contains 'Neg'.

    This is the version from the question and it still fails: the final
    assignment to ``node.input`` raises ``TypeError``. It is kept as written
    so the error being asked about stays reproducible; only the block
    structure (indentation) has been restored.

    ``input_model_filepath``, ``output_model_filepath`` and ``tf`` are not
    defined in this block and must be provided by the surrounding code.
    """
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    nodes = []
    for node in graph_def.node:
        # Substring test on the node *name*, not on the op type.
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # The problem that graph still have some references to deleted node in other nodes inputs
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)

        # NOTE: this is the failing line. A repeated protobuf field cannot be
        # rebound with `=`; it has to be cleared and refilled in place
        # (`del node.input[:]` followed by `node.input.extend(inp_names)`).
        node.input = inp_names  # TypeError: Can't set composite field

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())
答案 0 :(得分:0)
def delete_ops_from_graph():
    """Write a copy of a serialized GraphDef without the 'Neg' nodes.

    The graph is parsed from ``input_model_filepath``, every node whose name
    contains the substring ``'Neg'`` is dropped, every input string that
    contains ``'Neg'`` is removed from the remaining nodes, and the result is
    serialized to ``output_model_filepath``.

    ``input_model_filepath``, ``output_model_filepath`` and ``tf`` are not
    defined in this block and must be provided by the surrounding code.

    Note that the test is a plain substring match on names, so any node that
    merely has ``'Neg'`` somewhere in its name is affected too. Inputs are
    removed, not rewired, so consumers of a dropped node end up with fewer
    inputs than before.
    """
    graph_def = tf.GraphDef()
    with open(input_model_filepath, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Delete nodes
    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # Delete references to deleted nodes
    for node in mod_graph_def.node:
        inp_names = [inp for inp in node.input if 'Neg' not in inp]
        # A repeated protobuf field cannot be assigned to directly, so it is
        # emptied and refilled in place instead.
        del node.input[:]
        node.input.extend(inp_names)

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())
答案 1 :(得分:0)
上一个答案很好,但我建议把被删除节点的输入接到它的后继节点上。例如,假设有一条链 A-input b->B-input c->C-input d->D,现在要删除节点 B,那么我们不应该只是删除 input c,还应该把它替换为 input b。
请看下面的代码:
# remove node and connect its input to follower
def remove_node(graph_def, node_name, input_name):
    """Return a new graph without ``node_name``, rewiring its consumers.

    Every node except the one named ``node_name`` is copied into a fresh
    graph object of the same type as ``graph_def``. In the copied nodes, each
    input that is exactly equal to ``node_name`` is replaced by
    ``input_name``. If the removed node had no inputs there is nothing to
    rewire to, so such references are dropped instead.

    Only inputs that compare equal to ``node_name`` are touched; any other
    spelling of a reference to that node (for example with a ``^`` prefix or
    a ``:N`` suffix) is left as it is.

    Args:
        graph_def: Graph object with a ``node`` sequence whose items have
            ``name`` and ``input`` attributes (e.g. a TensorFlow GraphDef).
        node_name: Name of the node to remove.
        input_name: Which input of the removed node its consumers should be
            connected to. Ignored when the removed node has no inputs.

    Returns:
        A new graph object of ``type(graph_def)`` holding the kept nodes.
        If no node is named ``node_name`` the nodes are copied unchanged.

    Raises:
        AssertionError: If the node to remove has inputs and ``input_name``
            is not one of them.
    """
    kept = []
    removed = False
    input_of_removed_node = ''
    for node in graph_def.node:
        if node.name != node_name:
            kept.append(node)
            continue
        has_inputs = len(node.input) > 0
        # Raised explicitly (not via `assert`) so the check also runs
        # when Python is started with -O.
        if has_inputs and input_name not in node.input:
            raise AssertionError(
                "Node input to use is not among inputs of node to remove")
        input_of_removed_node = input_name if has_inputs else ''
        removed = True
        print("Removing {} and using its input {}".format(node.name,
              input_of_removed_node))

    # Build the result from the argument's own type so the function does not
    # depend on a particular `tf` namespace exposing GraphDef.
    mod_graph_def = type(graph_def)()
    # For protobuf messages `extend` stores copies of the nodes, so the
    # rewiring below is expected to leave the caller's graph_def untouched.
    mod_graph_def.node.extend(kept)
    if not removed:
        return mod_graph_def

    # modify inputs where required
    # removed name must be replaced with input from removed node
    for node in mod_graph_def.node:
        if node_name not in node.input:
            continue
        inp_names = []
        for inp in node.input:
            if inp != node_name:
                inp_names.append(inp)
            elif input_of_removed_node:
                print("For node {} replacing input {} with {}".format(
                    node.name, inp, input_of_removed_node))
                inp_names.append(input_of_removed_node)
            else:
                # The removed node had no inputs: an empty input name would
                # not refer to any node, so the reference is dropped.
                print("For node {} dropping input {}".format(node.name, inp))
        # Repeated protobuf fields cannot be assigned to; clear and refill.
        del node.input[:]
        node.input.extend(inp_names)
    return mod_graph_def