从TensorFlow图中移除dropout操作

时间:2016-11-01 11:11:08

标签: arm tensorflow

我有一个已经训练好的冻结图(frozen graph),想在ARM设备上运行它。基本上,我用的是 contrib/pi_examples/label_image 示例,只是把 Inception 换成了我自己的网络。我的网络在训练时使用了 dropout,而它现在带来了问题:

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs.  Registered kernels:
  device='CPU'; T in [DT_FLOAT]
  device='CPU'; T in [DT_INT32]
  device='GPU'; T in [DT_STRING]
  device='GPU'; T in [DT_BOOL]
  device='GPU'; T in [DT_INT32]
  device='GPU'; T in [DT_FLOAT]

 [[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]

我能想到的一个解决方案是重新编译包含相应算子内核的TF静态库。另一方面,直接从网络中移除 dropout 操作,让网络更简单、更快,可能是更好的办法。有没有办法做到这一点?

感谢。

2 个答案:

答案 0(得分:5)

#!/usr/bin/env python2

import argparse

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def print_graph(input_graph):
    """Print one line per node of a graph: ``name : op ( inputs )``.

    Args:
        input_graph: a ``GraphDef``-like object whose ``node`` field is an
            iterable of nodes exposing ``name``, ``op`` and ``input``.
    """
    for node in input_graph.node:
        # Single-argument print(...) prints identically on Python 2 and is
        # the only form that compiles on Python 3.
        print("{0} : {1} ( {2} )".format(node.name, node.op, node.input))

def _node_name(input_name):
    """Return the node name referenced by a NodeDef input string.

    NodeDef inputs may carry a ``^`` prefix (control dependency) and/or a
    ``:N`` output-port suffix; both are removed, e.g. ``"^a/b:1"`` -> ``"a/b"``.
    """
    if input_name.startswith('^'):
        input_name = input_name[1:]
    base, sep, port = input_name.rpartition(':')
    return base if sep and port.isdigit() else input_name


def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    """Return a copy of *input_graph* with one dropout block cut out.

    Every node whose name starts with ``drop_scope + '/'`` is removed, as is
    the node named *pl_name*. In the node named *output_after*, any input
    that reads output 0 of ``drop_scope + '/cond/Merge'`` (written either as
    ``.../Merge`` or ``.../Merge:0``) is redirected to *input_before*, so the
    layer after the dropout reads the layer before it directly. All other
    nodes are copied unchanged. The graph's ``versions`` and ``library``
    fields are carried over when present.

    Args:
        input_graph: source ``GraphDef``; it is not modified.
        drop_scope: name scope of the dropout block to remove.
        input_before: input name that replaces the dropout output.
        output_after: name of the node whose dropout input is rewired.
        pl_name: name of one additional node to remove (in this script's
            original use, the ``is_training_pl`` placeholder).

    Returns:
        A new ``GraphDef`` containing the remaining nodes.

    Raises:
        ValueError: if a remaining node still references a removed node,
            i.e. the stripped graph would contain a dangling input.
    """
    scope_prefix = drop_scope + '/'
    merge_name = scope_prefix + 'cond/Merge'
    # "name" and "name:0" both denote output 0 of the Merge node.
    merge_refs = (merge_name, merge_name + ':0')

    dropped = set()
    nodes_after_strip = []
    for node in input_graph.node:
        if node.name.startswith(scope_prefix) or node.name == pl_name:
            dropped.add(node.name)
            continue

        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            new_input = [input_before if name in merge_refs else name
                         for name in new_node.input]
            del new_node.input[:]
            new_node.input.extend(new_input)
        nodes_after_strip.append(new_node)

    # Fail here rather than writing a graph TensorFlow cannot import, e.g.
    # when another node also consumes the dropout output or *pl_name*.
    for node in nodes_after_strip:
        for name in node.input:
            if _node_name(name) in dropped:
                raise ValueError(
                    "Node '%s' still references removed node '%s'"
                    % (node.name, name))

    output_graph = graph_pb2.GraphDef()
    # Preserve graph-level metadata, which the node list alone does not carry.
    if input_graph.HasField('versions'):
        output_graph.versions.CopyFrom(input_graph.versions)
    if input_graph.HasField('library'):
        output_graph.library.CopyFrom(input_graph.library)
    output_graph.node.extend(nodes_after_strip)
    return output_graph

def _load_graph_def(path, binary):
    """Read a ``GraphDef`` from *path*.

    Args:
        path: file path readable through ``tf.gfile``.
        binary: if true, parse a serialized binary protobuf; otherwise parse
            UTF-8 protobuf text format.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    graph_def = tf.GraphDef()
    # Read raw bytes in both modes and decode text explicitly: opening in "r"
    # and then calling .decode() only works on Python 2.
    with tf.gfile.FastGFile(path, "rb") as f:
        data = f.read()
    if binary:
        graph_def.ParseFromString(data)
    else:
        text_format.Merge(data.decode("utf-8"), graph_def)
    return graph_def


def _save_graph_def(graph_def, path, binary):
    """Write *graph_def* to *path* as binary protobuf or as text format."""
    if binary:
        with tf.gfile.GFile(path, "wb") as f:
            f.write(graph_def.SerializeToString())
    else:
        with tf.gfile.GFile(path, "w") as f:
            f.write(text_format.MessageToString(graph_def))


def main():
    """Command-line entry point: strip one dropout block from a frozen graph.

    Reads ``--input-graph`` and removes the dropout scope and placeholder via
    ``strip``. It prints the node list before and after, then writes the
    result to ``--output-graph``.
    """
    parser = argparse.ArgumentParser(
        description="Remove a dropout block from a frozen TensorFlow graph.")
    parser.add_argument('--input-graph', action='store', dest='input_graph',
                        required=True)
    # Binary is the default. The --*-binary flags alone could never be turned
    # off (store_true with default=True), so --*-text flags are provided to
    # reach the text-format code paths.
    parser.add_argument('--input-binary', action='store_true', default=True,
                        dest='input_binary')
    parser.add_argument('--input-text', action='store_false',
                        dest='input_binary',
                        help='read the input graph as protobuf text format')
    parser.add_argument('--output-graph', action='store', dest='output_graph',
                        required=True)
    parser.add_argument('--output-binary', action='store_true',
                        dest='output_binary', default=True)
    parser.add_argument('--output-text', action='store_false',
                        dest='output_binary',
                        help='write the output graph as protobuf text format')
    # Node names passed to strip(); defaults are the previously hard-coded values.
    parser.add_argument('--drop-scope', default=u'l_fc1_dropout',
                        help='name scope of the dropout block to remove')
    parser.add_argument('--input-before', default=u'l_fc1/Relu',
                        help='input that replaces the dropout output')
    parser.add_argument('--output-after', default=u'prediction/MatMul',
                        help='node whose dropout input is rewired')
    parser.add_argument('--placeholder', dest='pl_name',
                        default=u'is_training_pl',
                        help='extra node to remove')

    args = parser.parse_args()

    if not tf.gfile.Exists(args.input_graph):
        # Report on stderr and exit non-zero instead of returning success.
        parser.error("Input graph file '" + args.input_graph +
                     "' does not exist!")

    input_graph_def = _load_graph_def(args.input_graph, args.input_binary)

    print("Before:")
    print_graph(input_graph_def)
    output_graph_def = strip(input_graph_def, args.drop_scope,
                             args.input_before, args.output_after,
                             args.pl_name)
    print("After:")
    print_graph(output_graph_def)

    _save_graph_def(output_graph_def, args.output_graph, args.output_binary)
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == "__main__":
    main()

答案 1(得分:3)

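这是一个更通用的解决方案(其中 `node_name_from_input` 和 `node_from_map` 是 TensorFlow 的 `tensorflow/python/tools/optimize_for_inference_lib.py` 中的辅助函数,`input_node_map` 是从节点名到 NodeDef 的映射):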

# Bypass dropout layers. For every input that reads a node named
# ".../dropout*/cond/Merge", walk Merge -> Identity -> the Identity's input
# node (presumably the cond's Switch). Reconnect the input to that node's
# first input, i.e. the tensor that entered the dropout cond.
# Relies on names defined elsewhere: temp_graph_def, input_node_map
# (presumably node name -> NodeDef), node_name_from_input and node_from_map.
# Structural checks raise ValueError instead of using assert, which is
# stripped under `python -O` and would let wrong nodes be rewired silently.
for node in temp_graph_def.node:
    for idx, input_name in enumerate(node.input):
        input_clean = node_name_from_input(input_name)
        if not (input_clean.endswith('/cond/Merge')
                and input_clean.split('/')[-3].startswith('dropout')):
            continue
        # The Merge's first input is expected to be the cond's Identity node.
        identity = node_from_map(input_node_map, input_name).input[0]
        if identity.split('/')[-1] != 'Identity':
            raise ValueError(
                "Expected an Identity node as first input of %r, got %r"
                % (input_clean, identity))
        parent = node_from_map(input_node_map,
                               node_from_map(input_node_map, identity).input[0])
        # The node feeding the Identity must be predicated on a pred_id node.
        pred_id = parent.input[1]
        if pred_id.split('/')[-1] != 'pred_id':
            raise ValueError(
                "Expected %r to be predicated on a pred_id node, got %r"
                % (parent.name, pred_id))
        node.input[idx] = parent.input[0]