我已经下载了具有预训练模型的网络。我在网络中添加了一些层和参数,我想使用这个经过预训练的模型来初始化原始参数,并自己亲自初始化新添加的参数。我使用以下代码:
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "output/saver-test")
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
但是我遇到了错误:“在检查点中找不到关键的global_step”,此错误是因为我有一些在预训练模型中不存在的新参数。但是我该如何解决这个问题呢?而且,我想使用此代码“ sess.run(tf.global_variables_initializer())”来初始化新添加的参数,但是从预训练模型中提取的参数将被其覆盖吗?
答案 0 :(得分:0)
发生这种情况是因为您的网络与加载的网络不完全匹配。 您可以使用类似这样的选择性检查点加载器:
reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
restore_dict = dict()
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
if reader.has_tensor(tensor_name):
print('has tensor ', tensor_name)
restore_dict[tensor_name] = v
restore_dict['my_new_var_scope/my_new_var'] = self.get_my_new_var_variable()
get_my_new_var_variable()类似于以下内容:
def get_my_new_var_variable(self):
with tf.variable_scope("my_new_var_scope",reuse=tf.AUTO_REUSE):
my_new_var = tf.get_variable("my_new_var", dtype=tf.int32,initializer=tf.constant([23, 42]))
return my_new_var
加载权重:
self.saver = tf.train.Saver(restore_dict)
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
编辑:
请注意,为了避免覆盖已加载的变量,可以使用以下方法:
def initialize_uninitialized(sess):
global_vars = tf.global_variables()
is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
if len(not_initialized_vars):
sess.run(tf.variables_initializer(not_initialized_vars))
或者在加载变量之前简单地调用tf.global_variables_initializer()
应该在这里起作用。