如何在模型微调时继续丢失/步骤信息

时间:2018-05-14 03:43:02

标签: tensorflow

我使用tf.train.Supervisor来微调tensorflow中的基本模型。以下是相关代码。

sv = tf.train.Supervisor(
     logdir=args.checkpoint_dir,
     save_summaries_secs=args.summary_interval,
     save_model_secs=args.checkpoint_interval,
     init_fn=load_initial_weights_insess) 

#using moving average to update loss and psnr
update_ma = ema.apply([loss, psnr])

def load_initial_weights_insess(sess):
    log.info("------------------------------------------")
    if len(initial_ckpt) <= 0:
        return
    log.info('Prepare to load initial weights from {}'.format(initial_ckpt))
    log.info("Variables to load initial weights are:")
    for v in initial_variables:
        log.info("----{}".format(v.name))
    initial_saver.restore(sess, initial_ckpt)
    log.info("Finished load initial weights")

代码可以正常运行。但它无法继续基本模型中的损失/步骤信息。也就是说,它总是启动丢失/步骤信息,而不是在微调处理中继续它们。从我的角度来看,损失的初始值(0)可能会导致错误的优化方向。

#Training log for the basic model, the following log is from the end step:
Step 88689 | loss = 0.0551 | psnr = 30.8 dB 

#Finetune log:
Step 0   | loss = 0.0    | psnr = 0.0 dB 
Step 52  | loss = 0.0012 | psnr = 4.1 dB 
Step 103 | loss = 0.0029 | psnr = 7.2 dB 

所以我的问题是如何在模拟模型时继续丢失/步骤信息。非常感谢提前。

2 个答案:

答案 0 :(得分:1)

初始化tf.train.Supervisor()时,可以传递global_step参数。您还可以确保将global_step变量添加到SAVEABLE_OBJECTS集合中,从而使用模型进行保存/恢复。这是在TensorFlow source

中收集保存和恢复对象的方法
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
        ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))

损失的问题有些不同,因为当你开始微调时,它没有任何价值。所以你可以做两种方法。一种是再次将其添加到SAVEABLE_OBJECTS集合中并使用模型保存/恢复它。另一个是使用tf.train.ExponentialMovingAverage()zero_debias参数进行设置(在代码中不可见)。这将确保损失的初始值为零不会使移动平均值向下偏移

答案 1 :(得分:0)

在预训练中,初始化“global_step”变量并在保存之前将其与初步优化器一起使用。 tf.train.Superviser然后可以通过 init 从初步培训会话中传递恢复的global_step,检查文档here