Retrain a frozen *.pb model in TensorFlow

Date: 2018-10-31 13:51:04

Tags: tensorflow

  

How do I import a frozen protobuf so that it can be retrained?

Every method I have found online requires a checkpoint. Is there a way to read the protobuf so that the kernel and bias constants are converted back into variables?


Edit 1: This is similar to the question: How to retrain model in graph (.pb)?

I looked at DeepSpeech, which was suggested in the answers to that question. They seem to have dropped initialize_from_frozen_model, and I could not find the reason why.


Edit 2: I tried creating a new GraphDef object in which the kernels and biases are replaced with variables:

import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2

probable_variables = [...]  # kernels and biases of Conv2D and MatMul

new_graph_def = tf.GraphDef()

with tf.Session(graph=graph) as sess:
    for n in sess.graph_def.node:
        if n.name in probable_variables:
            # create a variable op with the same name as the original constant
            nn = new_graph_def.node.add()
            nn.name = n.name
            nn.op = 'VariableV2'
            # reuse the dtype and shape stored in the original Const node
            nn.attr['dtype'].CopyFrom(n.attr['dtype'])
            nn.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(
                shape=n.attr['value'].tensor.tensor_shape))
        else:
            # copy all other nodes unchanged
            nn = new_graph_def.node.add()
            nn.CopyFrom(n)

Not sure whether I am on the right track. I also do not know how to set trainable=True on a NodeDef object.

5 Answers:

Answer 0 (score: 2)

Actually, the code snippet you provided is heading in the right direction :)


Step 1: Get the names of the previously trainable variables

The trickiest part is getting the names of the previously trainable variables. Hopefully the model was created with a high-level framework such as keras or tf.slim, which wrap the variables nicely in names like conv2d_1/kernel, dense_1/bias, batch_normalization/gamma, etc.

If you are not sure, the most useful thing to do is to visualize the graph...

import tensorflow as tf

# read the frozen graph definition
with tf.gfile.GFile('frozen.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# now build the graph in the memory and visualize it
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="prefix")
    writer = tf.summary.FileWriter('out', graph)
    writer.close()

...using TensorBoard:

$ tensorboard --logdir out/

and see for yourself what the graph looks like and what the names are.


Step 2: Replace the constants with variables (the fun part :D)

All you need is the magical library called tf.contrib.graph_editor. Now, let's say you have stored the names of the previously trainable ops (previously variables, now Const) in probable_variables (just as in your Edit 2).

Note: remember the difference between ops, tensors, and variables. Ops are elements of the graph, a tensor is a buffer that holds the result of an op, and a variable is a wrapper around a tensor, made up of three ops: assign (called when the variable is initialized), read (called by other ops, e.g. conv2d), and a ref tensor (which holds the value).

Note 2: graph_editor can only be run outside a session; you cannot make any graph modification online!

import numpy as np
import tensorflow.contrib.graph_editor as ge

# load the graphdef into memory, just as in Step 1
graph = load_graph('frozen.pb')

# create a variable for each constant, beware the naming
const_var_name_pairs = []
for name in probable_variables:
    var_shape = graph.get_tensor_by_name('{}:0'.format(name)).get_shape()
    var_name = '{}_a'.format(name)
    var = tf.get_variable(name=var_name, shape=var_shape, dtype='float32')
    const_var_name_pairs.append((name, var_name))

# from now we're going to work with GraphDef
name_to_op = dict([(n.name, n) for n in graph.as_graph_def().node])

# magic: now we swap the outputs of const and created variable
for const_name, var_name in const_var_name_pairs:
    const_op = name_to_op[const_name]
    var_reader_op = name_to_op[var_name + '/read']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

# Now we can safely create a session and copy the values
sess = tf.Session(graph=graph)
for const_name, var_name in const_var_name_pairs:
    ts = graph.get_tensor_by_name('{}:0'.format(const_name))
    var = tf.get_variable(var_name)
    var.load(ts.eval(sess))

# All done! Now you can make sure everything is correct by visualizing
# and calculate outputs for some inputs.

PS: this code is not tested; however, I have been using graph_editor and performing network surgery quite often lately, so I think it should be mostly correct :)
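
To actually retrain the model after the swap, the new variables just need to be picked up by an optimizer (tf.get_variable creates them as trainable by default). Below is a minimal, hypothetical sketch of that step; the output tensor name 'logits:0', the labels placeholder, the loss, and input_tensor are assumptions for illustration, not part of the graph above:

with graph.as_default():
    # Assumed names, adjust to your model: 'logits:0' is a hypothetical output tensor.
    logits = graph.get_tensor_by_name('logits:0')
    labels = tf.placeholder(tf.int64, shape=[None], name='labels')
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

    # The swapped-in variables are in the TRAINABLE_VARIABLES collection,
    # so minimize() will update exactly the former kernels and biases.
    train_op = tf.train.GradientDescentOptimizer(1e-3).minimize(loss)

# Reuse the session in which the constant values were loaded into the variables;
# do not run global_variables_initializer() here, as it would overwrite them.
# sess.run(train_op, feed_dict={input_tensor: x_batch, labels: y_batch})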

Answer 1 (score: 2)

I have verified @FalconUA's solution with tested code. A slight modification was needed (in particular, I use the initializer option of get_variable to initialize the variables correctly). Here it is!

Assuming the frozen model is stored in frozen_graph.pb:

import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

probable_variables = [...] # kernels and biases of Conv2D and MatMul
tf_graph = load_pb('frozen_graph.pb')  # load_pb is defined at the end of this answer

const_var_name_pairs = []
with tf_graph.as_default() as g:

    for name in probable_variables:
        tensor = g.get_tensor_by_name('{}:0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,
                              initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variables names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

Note: if you save the converted model and view it in TensorBoard or Netron, you will see that Variables have taken the place of the Constants. You will also see a bunch of dangling Constants, which you can optionally remove.

I have verified that the weight values are the same between the frozen and unfrozen versions.

Here is the load_pb function:

import tensorflow as tf
# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
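
As a sanity check, here is a rough, untested sketch (not part of the original answer) of how that verification could look, comparing each original constant with the variable that replaced it; it assumes tf_graph and const_var_name_pairs from the snippet above:

import numpy as np
import tensorflow as tf

with tf_graph.as_default():
    with tf.Session() as sess:
        # constant_initializer makes each variable start with the frozen values
        sess.run(tf.global_variables_initializer())
        for const_name, var_name in const_var_name_pairs:
            const_val = sess.run('{}:0'.format(const_name))
            var_val = sess.run('{}:0'.format(var_name))
            assert np.allclose(const_val, var_val), const_name
print('All swapped variables match the original constants.')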

Answer 2 (score: 1)

Thanks to @FalconUA and @Max wu. I am adding a quick way to get the variable names.

import tensorflow as tf


# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


tf_graph = load_pb('mobilenet_v1_1.0_224_frozen_ccl.pb')
variables = [op.name for op in tf_graph.get_operations() if op.type == "Const"]
print(variables)
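
Note that this lists every Const op, including shapes and other hyper-parameters baked into the graph. A possible (hypothetical) refinement is to keep only the constants that feed compute ops such as Conv2D or MatMul, i.e. the ones that used to be kernels and biases; depending on how the model was frozen, the weights may reach those ops through an Identity (.../read) node, so the filter may need adjusting:

# Hypothetical filter: keep only constants consumed by typical weight-consuming ops
weight_like = []
for op in tf_graph.get_operations():
    if op.type != "Const":
        continue
    consumer_types = {c.type for out in op.outputs for c in out.consumers()}
    if consumer_types & {"Conv2D", "DepthwiseConv2dNative", "MatMul", "BiasAdd"}:
        weight_like.append(op.name)
print(weight_like)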

Answer 3 (score: 0)

import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

def protobuf_to_checkpoint_conversion(pb_model, ckpt_dir):
    graph = tf.Graph()
    with graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_model, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    image_tensor = graph.get_tensor_by_name('image_tensor:0')
    dummy = np.random.random((1, 512, 512, 3))

    with graph.as_default():
        config = tf.ConfigProto()
        with tf.Session(graph=graph, config=config) as sess:
            constant_ops = [op for op in graph.get_operations() if op.type == "Const"]
            vars_dict = {}
            for constant_op in constant_ops:
                name = constant_op.name
                const = constant_op.outputs[0]
                shape = const.shape
                # create a variable (initialized with zeros) for every constant
                var = tf.get_variable(name, shape, dtype=const.dtype,
                                      initializer=tf.zeros_initializer())
                vars_dict[name] = var

            print('INFO: Initializing variables')
            init = tf.global_variables_initializer()
            sess.run(init)

            print('INFO: Loading vars')
            for constant_op in tqdm(constant_ops):
                name = constant_op.name
                if 'FeatureExtractor' in name or 'BoxPredictor' in name:
                    const = constant_op.outputs[0]
                    var = vars_dict[name]
                    # copy the frozen constant's value into the variable
                    var.load(sess.run(const, feed_dict={image_tensor: dummy}), sess)

            saver = tf.train.Saver(var_list=vars_dict)
            ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
            saver.save(sess, ckpt_path)
    return graph, vars_dict
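A brief usage sketch (the file paths and the restore step are illustrative assumptions, not from the original answer):

# Hypothetical usage; file paths are placeholders.
graph, vars_dict = protobuf_to_checkpoint_conversion('frozen_inference_graph.pb', './ckpt')

# The resulting checkpoint can later be restored for fine-tuning, e.g.:
# saver = tf.train.Saver(var_list=vars_dict)
# with tf.Session(graph=graph) as sess:
#     saver.restore(sess, './ckpt/model.ckpt')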

Reference: https://github.com/yeephycho/tensorflow-face-detection/issues/42#issuecomment-455325984

Answer 4 (score: 0)

import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

def protobuf_to_checkpoint_conversion(pb_model, ckpt_dir):
    graph = tf.Graph()
    with graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_model, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    image_tensor = graph.get_tensor_by_name('image_tensor:0')
    dummy = np.random.random((1, 512, 512, 3))

    with graph.as_default():
        config = tf.ConfigProto()
        with tf.Session(graph=graph, config=config) as sess:
            constant_ops = [op for op in graph.get_operations() if op.type == "Const"]
            vars_dict = {}
            for constant_op in constant_ops:
                name = constant_op.name
                const = constant_op.outputs[0]
                shape = const.shape
                # create a variable (initialized with zeros) for every constant
                var = tf.get_variable(name, shape, dtype=const.dtype,
                                      initializer=tf.zeros_initializer())
                vars_dict[name] = var

            print('INFO: Initializing variables')
            init = tf.global_variables_initializer()
            sess.run(init)

            print('INFO: Loading vars')
            for constant_op in tqdm(constant_ops):
                name = constant_op.name
                if 'FeatureExtractor' in name or 'BoxPredictor' in name:
                    const = constant_op.outputs[0]
                    var = vars_dict[name]
                    # copy the frozen constant's value into the variable
                    var.load(sess.run(const, feed_dict={image_tensor: dummy}), sess)

            saver = tf.train.Saver(var_list=vars_dict)
            ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
            saver.save(sess, ckpt_path)
    return graph, vars_dict

Reference: https://github.com/yeephycho/tensorflow-face-detection/issues/42#issuecomment-455325984