在其他图表中恢复tf变量

时间:2019-04-03 12:09:16

标签: python tensorflow

我想在其他模型的另一个可分离卷积中使用经过预训练的separable convolution(这是更大模块的一部分)。
在训练有素的模块中,我尝试过

with tf.variable_scope('sep_conv_ker' + str(input_shape[-1])):
            sep_conv2d = tf.reshape(
            tf.layers.separable_conv2d(inputs_flatten,input_shape[-1] , 
            [1,input_shape[-2]]
            trainable=trainable),
            [inputs_flatten.shape[0],1,input_shape[-1],INNER_LAYER_WIDTH]) 

        all_variables = tf.trainable_variables()
        scope1_variables = tf.contrib.framework.filter_variables(all_variables, include_patterns=['sep_conv_ker'])
        sep_conv_weights_saver = tf.train.Saver(scope1_variables, sharded=True, max_to_keep=20)

sess.run

sep_conv_weights_saver.save(sess,os.path.join(LOG_DIR + MODEL_SPEC_LOG_DIR,
                                                              "init_weights",MODEL_SPEC_SUFFIX + 'epoch_' + str(epoch) + '.ckpt'))

但是我无法理解何时以及如何将权重加载到另一个模块中的separable convolution上,它具有不同的名称和不同的范围,
此外,当我使用已定义的tf.layer时,是否意味着我需要访问新图中的每个权重并对其进行分配?

我当前的解决方案不起作用,我认为权重在赋值后就已经初始化了
此外,仅加载几个权重来加载整个新图似乎很奇怪,是吗?是吗?

        ###IN THE OLD GRAPH###
        all_variables = tf.trainable_variables()
        scope1_variables = tf.contrib.framework.filter_variables(all_variables, include_patterns=['sep_conv_ker'])
        vars = dict((var.op.name.split("/")[-1] + str(idx), var) for idx,var in enumerate(scope1_variables))
        sep_conv_weights_saver = tf.train.Saver(vars, sharded=True, max_to_keep=20)

在新图中,该函数基本上从旧图中获取变量并进行分配,加载meta_graph是多余的

def load_pretrained(sess):
    sep_conv2d_vars = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if ("sep_conv_ker" in var.op.name)]
    var_dict = dict((var.op.name.split("/")[-1] + str(idx), var) for idx, var in enumerate(sep_conv2d_vars))
    new_saver = tf.train.import_meta_graph(
        tf.train.latest_checkpoint('log/train/sep_conv_ker/global_neighbors40/init_weights') + '.meta')
    # saver = tf.train.Saver(var_list=var_dict)
    new_saver.restore(sess,
                      tf.train.latest_checkpoint('log/train/sep_conv_ker/global_neighbors40/init_weights'))

    graph = tf.get_default_graph()
    sep_conv2d_trained = dict(("".join(var.op.name.split("/")[-2:]),var) for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if ("sep_conv_ker_init" in var.op.name))
    for var in sep_conv2d_vars:
        tf.assign(var,sep_conv2d_trained["".join(var.op.name.split("/")[-2:])])

1 个答案:

答案 0 :(得分:1)

您需要确保变量在变量文件中以及加载变量的图形中具有相同的变量。您可以编写一个脚本来转换变量名称。

  1. 使用tf.contrib.framework.list_variables(ckpt),您可以找出检查点中具有什么形状的变量,并使用新名称创建各自的变量(我相信,您可以编写一个会固定名称的正则表达式)并纠正形状。
  2. 然后,您使用tf.contrib.framework.load_checkpoint(ckpt)加载原始变量,并分配操作tf.assign(var, loaded),该操作将为变量分配具有已保存值的新名称。
  3. 在会话中运行分配操作。
  4. 保存新变量。

最小示例:

原始模型(范围“回归”中的变量):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 3]) 
regression = tf.layers.dense(x, 1, name="regression")

session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.trainable_variables())

saver.save(session, './model')

重命名脚本:

import tensorflow as tf

assign_ops = []
reader = tf.contrib.framework.load_checkpoint("./model")
for name, shape in tf.contrib.framework.list_variables("./model"):
    new_name = name.replace("regression/", "foo/bar/")
    new_var = tf.get_variable(new_name, shape)
    assign_ops.append(tf.assign(new_var, reader.get_tensor(name)))

session = tf.Session()
saver = tf.train.Saver(tf.trainable_variables())

session.run(assign_ops)
saver.save(session, './model-renamed')

加载重命名变量的模型(分数“ foo / bar”中相同的变量):

import tensorflow as tf

with tf.variable_scope("foo"):
    x = tf.placeholder(tf.float32, [None, 3]) 
    regression = tf.layers.dense(x, 1, name="bar")

session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.trainable_variables())

saver.restore(session, './model-renamed')