如何恢复预训练模型以初始化参数

时间:2018-09-27 07:59:08

标签: python tensorflow

我已经下载了具有预训练模型的网络。我在网络中添加了一些层和参数,我想使用这个经过预训练的模型来初始化原始参数,并自己亲自初始化新添加的参数。我使用以下代码:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "output/saver-test")
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

但是我遇到了错误:“在检查点中找不到关键的global_step”,此错误是因为我有一些在预训练模型中不存在的新参数。但是我该如何解决这个问题呢?而且,我想使用此代码“ sess.run(tf.global_variables_initializer())”来初始化新添加的参数,但是从预训练模型中提取的参数将被其覆盖吗?

1 个答案:

答案 0 :(得分:0)

发生这种情况是因为您的网络与加载的网络不完全匹配。 您可以使用类似这样的选择性检查点加载器:

  reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    restore_dict['my_new_var_scope/my_new_var'] = self.get_my_new_var_variable()

get_my_new_var_variable()类似于以下内容:

    def get_my_new_var_variable(self):
    with tf.variable_scope("my_new_var_scope",reuse=tf.AUTO_REUSE):
        my_new_var = tf.get_variable("my_new_var", dtype=tf.int32,initializer=tf.constant([23, 42]))
    return my_new_var

加载权重:

self.saver = tf.train.Saver(restore_dict)
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))

编辑:

请注意,为了避免覆盖已加载的变量,可以使用以下方法:

def initialize_uninitialized(sess):
  global_vars = tf.global_variables()
  is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
  not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
  if len(not_initialized_vars):
    sess.run(tf.variables_initializer(not_initialized_vars))

或者在加载变量之前简单地调用tf.global_variables_initializer()应该在这里起作用。