使用MonitoredTrainingSession

时间:2017-07-16 23:45:23

标签: python-3.x tensorflow deep-learning

我正在InceptionV3上进行5种花卉数据集的转移学习。除输出图层外,所有图层都被冻结。我的实现很大程度上基于Tensorflow的Cifar10教程,输入数据集的格式与Cifar10相同。

我添加了一个MonitoredTrainingSession(就像在教程中一样)来报告一定数量的步骤后的准确性和损失。下面是MonitoredTrainingSession的代码部分(几乎与教程相同):

class _LoggerHook(tf.train.SessionRunHook):

    def begin(self):
        self._step = -1
        self._start_time = time.time()
    def before_run(self,run_context):
        self._step+=1
        return tf.train.SessionRunArgs([loss,accuracy])

    def after_run(self,run_context,run_values):
        if self._step % LOG_FREQUENCY ==0:
            current_time = time.time()
            duration = current_time - self._start_time
            self._start_time = current_time

            loss_value = run_values.results[0]
            acc = run_values.results[1]

            examples_per_sec = LOG_FREQUENCY/duration
            sec_per_batch = duration / LOG_FREQUENCY

            format_str = ('%s: step %d, loss = %.2f, acc = %.2f (%.1f examples/sec; %.3f sec/batch)')

            print(format_str %(datetime.now(),self._step,loss_value,acc,
                examples_per_sec,sec_per_batch))
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
if MODE == 'train':

    file_writer = tf.summary.FileWriter(LOGDIR,tf.get_default_graph())
    with tf.train.MonitoredTrainingSession(
            save_checkpoint_secs=70,
            checkpoint_dir=LOGDIR,
            hooks=[tf.train.StopAtStepHook(last_step=NUM_EPOCHS*NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()],
            config=config) as mon_sess:
        original_saver.restore(mon_sess,INCEPTION_V3_CHECKPOINT)
        print("Proceeding to training stage")

        while not mon_sess.should_stop():
            mon_sess.run(train_op,feed_dict={training:True})
            print('acc: %f' %mon_sess.run(accuracy,feed_dict={training:False}))
            print('loss: %f' %mon_sess.run(loss,feed_dict={training:False}))

当打印mon_sess.run(train_op...下的精确度和损失的两条线被删除时,after_run打印的损失和准确度,在它仅仅训练了20分钟之后,报告该模型表现非常好在训练集上,损失正在减少。即使是均线损失也报告了很好的结果。对于多个随机批次,它最终接近90%以上的准确度。

之后,培训课程报告了一段时间的高精度,我停止了训练课程,恢复了模型,并在同一训练集中随机批量运行。它表现不佳,只达到50%到85%的准确率。我确认它已正确恢复,因为它确实比具有未经训练的输出层的模型表现更好。

然后我从最后一个检查站重新开始训练。准确度最初很低,但在大约10次小批量运行后,准确度回到90%以上。然后我重复了这个过程,但是这次添加了两条线来评估训练操作后的损失和准确性。这两项评估报告称,该模型的问题趋于融合,表现不佳。虽然通过before_runafter_run,进行的评估现在只是偶尔显示出高准确度和低损失(结果跳了起来)。但仍有after_run有时会报告100%的准确性(我认为它不再一致的事实是因为after_runmon_sess.run(accuracy...)也会调用mon_sess.run(loss...)

为什么MonitoredTrainingSession报告的结果表明模型表现不佳呢? SessionRunArgs中的两个操作不是使用与train_op相同的小批量生成,表示在渐变更新之前批次的模型性能?

以下是我用于恢复和测试模型的代码(基于cifar10教程):

elif MODE == 'test':
    init = tf.global_variables_initializer()
    ckpt = tf.train.get_checkpoint_state(LOGDIR)
    if ckpt and ckpt.model_checkpoint_path:
        with tf.Session(config=config) as sess:
                init.run()
                saver = tf.train.Saver()
                print(ckpt.model_checkpoint_path)
                saver.restore(sess,ckpt.model_checkpoint_path)
                global_step = tf.contrib.framework.get_or_create_global_step()

                coord = tf.train.Coordinator()
                threads =[]
                try:
                    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,start=True))
                    print('model restored')
                    i =0
                    num_iter = 4*NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN/BATCH_SIZE
                    print(num_iter)
                    while not coord.should_stop() and i < num_iter:
                        print("loss: %.2f," %loss.eval(feed_dict={training:False}),end="")
                        print("acc: %.2f" %accuracy.eval(feed_dict={training:False}))
                        i+=1
                except Exception as e:
                    print(e)
                    coord.request_stop(e)
                coord.request_stop()
                coord.join(threads,stop_grace_period_secs=10)

更新

所以我能够解决这个问题。但是,我不确定它为什么会起作用。在初始模型的arg_scope中,我传入了一个is_training布尔占位符,用于批处理规范和初始使用的丢失。但是,当我删除占位符并将is_training关键字设置为true时,恢复模型时训练集的准确性非常高。这是之前表现不佳的同一型号检查站。当我训练它时,我总是将is_training占位符设置为true。在测试时将is_training设置为true意味着批处理Norm现在使用样本均值和方差。

为什么告诉Batch Norm现在使用样本平均值和样本标准偏差,就像在训练期间提高准确度一样?

这也意味着丢失层正在丢弃单位,并且在启用了丢失层的情况下,在训练集和测试集上测试期间模型的准确性更高。

更新2 我浏览了上面代码中arg_scope引用的tensorflow slim inceptionv3模型代码。我在Avg池8x8后删除了最终的丢失层,准确率保持在99%左右。但是,当我将is_training设置为仅为批量标准层的False时,精度会降低到70%左右。这是来自slim\nets\inception_v3.py的arg_scope和我的修改。

with variable_scope.variable_scope(
      scope, 'InceptionV3', [inputs, num_classes], reuse=reuse) as scope:
    with arg_scope(
        [layers_lib.batch_norm],is_training=False): #layers_lib.dropout], is_training=is_training):
      net, end_points = inception_v3_base(
          inputs,
          scope=scope,
          min_depth=min_depth,
          depth_multiplier=depth_multiplier)

我尝试了这一点,删除了丢失图层并保留了丢失图层,并将is_training=True传递到了dropout图层。

1 个答案:

答案 0 :(得分:2)

(总结来自dylan7在问题评论中的调试)

Batch norm依赖于变量来保存它规范化的摘要统计信息。只有is_training通过UPDATE_OPS集合为True时才会更新这些内容(请参阅batch_norm documentation)。如果这些更新操作没有运行(或者变量被覆盖),那么当is_training为False时,可能会有基于每个批处理丢失的瞬时“合理”统计数据(测试数据不是,也不应该是,用于通知batch_norm摘要统计信息)。