In TensorFlow, my model is based on a pre-trained model: I added some variables to the pre-trained graph and removed others. When I restore variables from the checkpoint file, I have to explicitly list every variable I added to the graph so that it gets excluded. For example, I do
exclude = [...]  # explicitly list all variables to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
Is there a simpler way? That is, if a variable is not in the checkpoint, simply don't try to restore it.
Answer 0 (score: 1)
The only thing you can do is to first build the same model as the one stored in the checkpoint and restore the checkpoint values into that model. Only after the variables of that model have been restored should you add new layers, remove existing ones, or change layer weights.
There is one important point to be careful about, though. After adding new layers you need to initialize them, but if you use tf.global_variables_initializer() you will lose the values of the layers you just reloaded. You should therefore initialize only the uninitialized weights, which you can do with the following function.
def initialize_uninitialized(sess):
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
    # for i in not_initialized_vars:  # only for testing
    #     print(i.name)
    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))
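For illustration, a minimal usage sketch; the `exclude` list and `checkpoint_path` are placeholders taken from the question's setup, not part of the original answer:
# Sketch: restore the overlapping variables first, then initialize only the new ones.
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
restore_saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    restore_saver.restore(sess, checkpoint_path)  # checkpoint_path: placeholder for your checkpoint file
    initialize_uninitialized(sess)  # new variables get their initializers, restored ones keep their values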
Answer 1 (score: 1)
You should first find out which variables are useful to you (that is, the variables that exist in your graph), and then restore only the intersection of those with the variables stored in the checkpoint, instead of restoring everything the checkpoint contains.
ckpt_var_names = {name for name, _ in tf.train.list_variables(checkpoint_dir)}  # names stored in the checkpoint
variables_can_be_restored = [v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) if v.name.split(':')[0] in ckpt_var_names]
Then restore them after defining a saver like this:
temp_saver = tf.train.Saver(variables_can_be_restored)
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
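Note that any graph variables that are not in this intersection still need to be initialized before use; a minimal sketch of that step, assuming the same session and collections as above:
restorable = set(variables_can_be_restored)
missing_vars = [v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) if v not in restorable]
sess.run(tf.variables_initializer(missing_vars))  # initialize only what the checkpoint did not provide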
Answer 2 (score: -1)
Here is a more complete answer, which works in a non-distributed setting:
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils

slim = tf.contrib.slim
def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    check_var_list = checkpoint_utils.list_variables(checkpoint_path)
    check_var_list = [x[0] for x in check_var_list]
    check_var_set = set(check_var_list)
    vars_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] in check_var_set]
    vars_not_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] not in check_var_set]
    return vars_in_checkpoint, vars_not_in_checkpoint
def create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint):
    model_ready_for_local_init_op = tf.report_uninitialized_variables(var_list=vars_in_checkpoint)
    model_init_vars_not_in_checkpoint = tf.variables_initializer(vars_not_in_checkpoint)
    restoration_saver = tf.train.Saver(vars_in_checkpoint)
    eg_scaffold = tf.train.Scaffold(saver=restoration_saver,
                                    ready_for_local_init_op=model_ready_for_local_init_op,
                                    local_init_op=model_init_vars_not_in_checkpoint)
    return eg_scaffold
all_vars = slim.get_variables()
ckpoint_file = tf.train.latest_checkpoint(output_chkpt_dir)
vars_in_checkpoint, vars_not_in_checkpoint = scan_checkpoint_for_vars(ckpoint_file, all_vars)
is_checkpoint_complete = len(vars_not_in_checkpoint) == 0

# Create session that can handle current checkpoint
if is_checkpoint_complete:
    # Checkpoint is full - all variables can be found there
    print('Using normal session')
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir=output_chkpt_dir,
                                             save_checkpoint_secs=save_checkpoint_secs,
                                             save_summaries_secs=save_summaries_secs)
else:
    # Checkpoint is partial - some variables need to be initialized
    print('Using easy going session')
    eg_scaffold = create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint)
    # Save all variables to next checkpoint
    saver = tf.train.Saver()
    hooks = [tf.train.CheckpointSaverHook(checkpoint_dir=output_chkpt_dir,
                                          save_secs=save_checkpoint_secs,
                                          saver=saver)]
    # Such a session is a little slower during the first iteration
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir=output_chkpt_dir,
                                             scaffold=eg_scaffold,
                                             hooks=hooks,
                                             save_summaries_secs=save_summaries_secs,
                                             save_checkpoint_secs=None)

with sess:
    .....