如何使用Tensorflow中的CheckpointReader恢复变量

时间:2016-08-25 05:51:35

标签: python-2.7 tensorflow

如果当前模型中有相同的变量名,我试图从检查点文件中恢复一些变量 我发现Tensorfow Github

有一些方法

所以我想做的是使用has_tensor("variable.name")检查检查点文件中的变量名称,如下所示,

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    print v.name
    if reader.has_tensor(v.name):
        print 'has tensor'
...

但我发现v.name同时返回变量namecolon+number。例如,我有变量名W_ob_o,然后v.name返回W_o:0, b_o:0

但是reader.has_tensor()要求namecolonnumberW_o, b_o

我的问题是:如何删除变量名末尾的colonnumber以便读取变量?
有没有更好的方法来恢复这些变量?

3 个答案:

答案 0 :(得分:6)

您可以使用string.split()来获取张量名称:

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]
    print tensor_name
    if reader.has_tensor(tensor_name):
        print 'has tensor'
...

接下来,让我用一个示例来说明如何从.cpkt文件中恢复每个可能的变量。首先,让我在v2中保存v3tmp.ckpt

import tensorflow as tf

v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')

saver = tf.train.Saver({'v2': v2, 'v3': v3})

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.save(sess, 'tmp.ckpt')

我将如何恢复tmp.ckpt中显示的每个变量(属于新图表):

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros([1]), name='v1')
    v2 = tf.Variable(tf.zeros([1]), name='v2')

    reader = tf.train.NewCheckpointReader('tmp.ckpt')
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    saver = tf.train.Saver(restore_dict)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, 'tmp.ckpt')
        print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]

此外,您可能希望确保形状和dtypes匹配。

答案 1 :(得分:1)

tf.train.NewCheckpointReader是一个创建CheckpointReader对象的漂亮方法。 CheckpointReader有几个非常有用的方法。与您的问题最相关的方法是get_variable_to_shape_map()。

  • get_variable_to_shape_map()提供了一个包含变量名称和形状的字典:



saved_shapes = reader.get_variable_to_shape_map()
print 'fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels']




请查看下面的快速教程: Loading Variables from Existing Checkpoints

答案 2 :(得分:0)

简单答案:

reader = tf.train.NewCheckpointReader(checkpoint_file)

variable1 = reader.get_tensor('layer_name1/layer_type_name')
variable2 = reader.get_tensor('layer_name2/layer_type_name')

现在,在修改这些变量之后,您可以将其分配回来。

layer_name1_var.set_weights([variable1, variable2])