TensorFlow:有没有办法将冻结图转换为检查点模型?

时间:2017-07-24 08:00:16

标签: python tensorflow

可以将检查点模型转换为冻结图(.ckpt文件到.pb文件)。但是,有没有一种将pb文件再次转换为检查点文件的反向方法?

我想它需要将常量转换回变量 - 有没有办法将正确的常量识别为变量并将它们恢复为检查点模型?

目前支持将变量转换为常量:https://www.tensorflow.org/api_docs/python/tf/graph_util/convert_variables_to_constants

但不是相反。

此处提出了类似的问题:Tensorflow: Convert constant tensor from pre-trained Vgg model to variable

但该解决方案依赖于使用ckpt模型来恢复权重变量。有没有办法从PB文件而不是检查点文件中恢复权重变量?这对于重量修剪很有用。

3 个答案:

答案 0 :(得分:1)

如果您具有构建网络的源代码,则可以相对容易地完成操作,因为冻结图方法未更改卷积/完全连接的名称,因此您基本上可以研究该图并匹配常量操作它们的变量匹配,然后只需将常量值加载到变量中即可。

如果您没有构建网络的代码,仍然可以完成,但是这样做不是直接的。

例如,您可以搜索图形中的所有节点并查找“常量”类型的操作,然后在找到“常量”类型的所有操作之后,可以查看该操作是否已连接到卷积/完全连接,例如。(或您可以只转换它所依赖的所有常量。)

找到要转换为变量的常量后,可以向包含常量值的图形添加变量,然后使用Tensorflow graph editor重新连接(使用reroute_ts方法) )之间的const操作。

完成后,您可以保存图形,并在再次加载图形时拥有变量(但请注意,常量将仍然保留在图形中,但是可以通过graph-transform对其进行优化。工具)

答案 1 :(得分:1)

  

如果您具有构建网络的源代码,则可以相对容易地完成操作,因为冻结图方法未更改卷积/完全连接的名称,因此您基本上可以研究该图并匹配常量操作使其变量匹配,并仅将常量值加载到变量中。 -由Almog David

感谢@ Almog David的出色回答;我正面临着与

完全相同的情况
  • 我有frozen_inference_graph.pb,但没有检查站;
  • 我有产生frozen_inference_graph.pb的源代码,但我不知道参数。

及以下是解决难题的三个步骤。

1。从frozen_inference_graph.pb

获取成对的节点名称和值
import tensorflow as tf
from tensorflow.python.framework import tensor_util

def get_node_values(old_graph_path):
    old_graph = tf.Graph()
    with old_graph.as_default():
        old_graph_def = tf.GraphDef()
        with tf.gfile.GFile(old_graph_path, "rb") as fid:
            serialized_graph = fid.read()
            old_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(old_graph_def, name='')

    old_sess = tf.Session(graph=old_graph)

    # get all the nodes from the graph def
    nodes = old_sess.graph.as_graph_def().node

    value_dict = {}
    for node in nodes:
        value = node.attr['value'].tensor
        try:
            # get name and value (numpy array) from tensor 
            value_dict[node.name] = tensor_util.MakeNdarray(value) 
        except:
            # some tensor doesn't have value; for example np.squeeze
            # just ignore it 
            pass
    return value_dict

value_dict = get_node_values("frozen_inference_graph.pb")

2。使用现有代码创建新图形;调整模型参数,直到新图中的所有节点都出现在value_dict

new_graph = tf.Graph()
with new_graph.as_default():
    tf.create_global_step()
    #existing code 
    # ...
    # ...
    # ...

    model_variables = tf.model_variables()
    unseen_variables = set(model_variable.name[:-2] for model_variable in model_variables) - set(value_dict.keys())
    print  ("\n".join(sorted(list(unseen_variables))))

3。将值分配给变量并保存到检查点(或保存到图形)

new_graph_path = "model.ckpt"
saver = tf.train.Saver(model_variables)

assign_ops = []
for variable in model_variables:
    print ("Assigning", variable.name[:-2])
    # variable names have ":0" but constant names doesn't have.
    value = value_dict[variable.name[:-2]]
    assign_ops.append(variable.assign(value))

sess =session.Session(graph = new_graph)
sess.run(tf.global_variables_initializer())
sess.run(assign_ops)
saver.save(sess, new_graph_path+"model.ckpt")

这是我想解决这个问题的唯一方法。但是,它仍然存在一些缺点:如果重新加载模型检查点,则会发现(以及所有有用的变量)很多不需要的assign变量,例如Assign_700/value。这是不可避免的,而且看起来很丑。如果您有更好的建议,请随时发表评论。谢谢。

答案 2 :(得分:0)

有一种 方法,可通过图形编辑器在TensorFlow中将常量转换回可训练的变量。但是,您将需要指定要转换的节点,因为我不确定是否有办法以健壮的方式自动检测到这一点。

以下是步骤:

步骤1:加载冻结的图

我们将.pb文件加载到图形对象中。

import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb('frozen_graph.pb')

步骤2:查找需要转换的常量

有两种方法可以列出图中节点的名称:

  • 使用this script进行打印
  • print([n.name for n in tf_graph.as_graph_def().node])

您要转换的节点可能按照“ Const”的名称命名。可以肯定的是,将图形加载到Netron中是个好主意,以查看哪些张量存储了可训练的权重。通常,可以安全地假设所有const节点都是变量。

确定了这些节点后,让我们将其名称存储在列表中:

to_convert = [...] # names of tensors to convert

步骤3:将常量转换为变量

运行此代码来转换您指定的常量。本质上,它为每个常量创建相应的变量,并使用GraphEditor从图中解开常量,然后将变量挂钩。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

const_var_name_pairs = []
with tf_graph.as_default() as g:

    for name in to_convert:
        tensor = g.get_tensor_by_name('{}:0'.format(name))
        with tf.Session() as sess:
            tensor_as_numpy_array = sess.run(tensor)
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape, \  
                      initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variables names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

步骤4:将结果另存为.ckpt

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = tf.train.Saver().save(sess, 'model.ckpt')
        print("Model saved in path: %s" % save_path)

中提琴!此时,您应该完成操作:)我自己就能完成此工作,并验证了模型权重是否得到保留-唯一的区别是该图现在可以训练了。请让我知道是否有任何问题。