How can I restore only the variables that exist in a checkpoint in TensorFlow?

Date: 2017-12-09 17:25:09

Tags: tensorflow

In TensorFlow, my model is based on a pre-trained model. I added some variables to the pre-trained model and removed others. When I restore variables from the checkpoint file, I have to explicitly list every variable I added to the graph so that it can be excluded. For example, I do:

exclude = [...]  # explicitly list all variable names to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)

Is there a simpler way? That is, whenever a variable is not in the checkpoint, just don't try to restore it.

3 Answers:

Answer 0 (score: 1):

The only thing you can do is to first build the same model as the one in the checkpoint and restore the checkpoint values into it. Only after the variables of that identical model have been restored can you add new layers, delete existing layers, or change the weights of a layer.

There is one important point you need to be careful about, though. After adding new layers you need to initialize them, but if you use tf.global_variables_initializer() you will lose the values of the layers you just reloaded. So you should initialize only the uninitialized weights, which you can do with the following function.

def initialize_uninitialized(sess):
    """Initialize only the variables that do not yet hold a value in the session."""
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]

    # for i in not_initialized_vars:  # only for testing
    #     print(i.name)

    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))
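
For example, a minimal usage sketch of the intended order of operations (the checkpoint path ckpt_path is a stand-in for illustration, and variables_to_restore is the list from the question):

saver = tf.train.Saver(variables_to_restore)  # only the variables that exist in the checkpoint
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)  # hypothetical checkpoint path
    initialize_uninitialized(sess)  # then initialize only the newly added variables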

Answer 1 (score: 1):

You should first find out all the useful variables (that is, the variables in your graph), and then restore the intersection of those with the variables stored in the checkpoint, rather than restoring everything the checkpoint contains.

ckpt_var_names = {name for name, _ in tf.train.list_variables(checkpoint_dir)}  # list_variables() yields (name, shape) pairs
variables_can_be_restored = [v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) if v.name.split(':')[0] in ckpt_var_names]

Then restore from it after defining a saver like this:

temp_saver = tf.train.Saver(variables_can_be_restored)
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
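
One caveat: variables that are in the graph but absent from the checkpoint are left untouched by this restore, so they still need to be initialized before use. A minimal sketch, reusing the names defined above:

not_restored = [v for v in tf.global_variables() if v not in variables_can_be_restored]
sess.run(tf.variables_initializer(not_restored))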

Answer 2 (score: -1):

This is a more complete answer, valid in a non-distributed setting:

import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils

slim = tf.contrib.slim


def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    """Split vars_to_check into variables found in the checkpoint and variables missing from it."""
    check_var_list = checkpoint_utils.list_variables(checkpoint_path)
    check_var_list = [x[0] for x in check_var_list]
    check_var_set = set(check_var_list)
    vars_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] in check_var_set]
    vars_not_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] not in check_var_set]
    return vars_in_checkpoint, vars_not_in_checkpoint


def create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint):
    """Build a tf.train.Scaffold that restores what it can and initializes the rest."""
    model_ready_for_local_init_op = tf.report_uninitialized_variables(var_list = vars_in_checkpoint)
    model_init_vars_not_in_checkpoint = tf.variables_initializer(vars_not_in_checkpoint)

    restoration_saver = tf.train.Saver(vars_in_checkpoint)
    eg_scaffold = tf.train.Scaffold(saver=restoration_saver,
                                    ready_for_local_init_op = model_ready_for_local_init_op,
                                    local_init_op = model_init_vars_not_in_checkpoint)
    return eg_scaffold


all_vars = slim.get_variables()
ckpoint_file = tf.train.latest_checkpoint(output_chkpt_dir)
vars_in_checkpoint, vars_not_in_checkpoint = scan_checkpoint_for_vars(ckpoint_file, all_vars)
is_checkpoint_complete = len(vars_not_in_checkpoint) == 0

# Create session that can handle current checkpoint
if is_checkpoint_complete:
    # Checkpoint is full - all variables can be found there
    print('Using normal session')
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir = output_chkpt_dir,
                                             save_checkpoint_secs = save_checkpoint_secs,
                                             save_summaries_secs = save_summaries_secs)
else:
    # Checkpoint is partial - some variables need to be initialized
    print('Using easy going session')
    eg_scaffold = create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint)
    # Save all variables to next checkpoint
    saver = tf.train.Saver()
    hooks = [tf.train.CheckpointSaverHook(checkpoint_dir = output_chkpt_dir,
                                          save_secs = save_checkpoint_secs,
                                          saver = saver)]
    # Such session is a little slower during the first iteration
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir = output_chkpt_dir,
                                             scaffold = eg_scaffold,
                                             hooks = hooks,
                                             save_summaries_secs = save_summaries_secs,
                                             save_checkpoint_secs = None)

with sess:
    .....
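
The elided body of the with block would then be the usual MonitoredTrainingSession loop; a hypothetical sketch, assuming a train_op has been defined elsewhere in the graph:

while not sess.should_stop():
    sess.run(train_op)  # hypothetical training op, not part of the original answer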