Image classification: validation loss stuck during training with Inception (v1)

Date: 2017-08-04 13:13:49

Tags: validation tensorflow deep-learning classification loss

I have built a small custom image-classification dataset with 4 classes. The training set has ~110,000 images and the validation set has ~6,000 images.

The problem I'm experiencing is that, during training, both the training accuracy (averaged over the most recent training batches) and the training loss improve, while the validation accuracy and loss stay the same.

This only happens with the Inception and ResNet models; if I train an AlexNet model on the same training and validation data, the validation loss and accuracy improve.

In my experiments I use several convolutional architectures imported from tensorflow.contrib.slim.nets.
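For context, one plausible set of imports consistent with the snippets below; the question does not show its import lines, so treat these as an assumption (the exact module layout differs between TF 1.x versions):

import tensorflow as tf

slim = tf.contrib.slim

# Model definitions bundled with TF 1.x contrib slim
from tensorflow.contrib.slim.nets import alexnet, inception

inception_v1 = inception.inception_v1
inception_v1_arg_scope = inception.inception_v1_arg_scope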

The code is organized as follows:

...

images, labels = preprocessing(..., train=True)
val_images, val_labels = preprocessing(..., train=False)

...

# AlexNet model
with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    logits, _ = alexnet.alexnet_v2(images, ..., is_training=True)
    tf.get_variable_scope().reuse_variables()
    val_logits, _ = alexnet.alexnet_v2(val_images, ..., is_training=False)

# Inception v1 model
with slim.arg_scope(inception_v1_arg_scope()):
    logits, _ = inception_v1(images, ..., is_training=True)
    val_logits, _ = inception_v1(val_images, ..., is_training=False, reuse=True)

# Losses on the training and validation batches
loss = my_stuff.loss(logits, labels)
val_loss = my_stuff.loss(val_logits, val_labels)

# Top-1 accuracy ops for the training and validation batches
training_accuracy_op = tf.nn.in_top_k(logits, labels, 1)
top_1_op = tf.nn.in_top_k(val_logits, val_labels, 1)
train_op = ...

...

Instead of using a separate eval script, I run the validation step at the end of each epoch. For debugging purposes I also run an early validation step (before any training), and I track the training accuracy by averaging the training predictions over the last x steps. A rough sketch of this loop is shown below.
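For concreteness, this is a minimal sketch of that loop; the helper names, the acc_window size, and the print format are made up for illustration, and the real code certainly differs:

import numpy as np

def run_validation(sess, top_1_op, val_loss_op, num_val_batches):
    # Average top-1 accuracy and loss over the whole validation set
    hits, losses = [], []
    for _ in range(num_val_batches):
        correct, l = sess.run([top_1_op, val_loss_op])
        hits.append(np.mean(correct))
        losses.append(l)
    print('precision @ 1 = %.4f val loss = %.2f' % (np.mean(hits), np.mean(losses)))

def train_loop(sess, ops, steps_per_epoch, num_val_batches, num_epochs, acc_window=50):
    train_op, loss_op, training_accuracy_op, top_1_op, val_loss_op = ops
    run_validation(sess, top_1_op, val_loss_op, num_val_batches)  # early val step
    recent = []
    for epoch in range(num_epochs):
        print('Starting epoch %d' % epoch)
        for step in range(1, steps_per_epoch + 1):
            _, l, correct = sess.run([train_op, loss_op, training_accuracy_op])
            recent = (recent + [np.mean(correct)])[-acc_window:]  # keep only the last x steps
            if step % 50 == 0:
                print('step %d, loss = %.2f, training_acc = %.4f' % (step, l, np.mean(recent)))
        run_validation(sess, top_1_op, val_loss_op, num_val_batches)  # end-of-epoch val step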

When I use the Inception v1 model (commenting out the AlexNet one), the logger output after 1 epoch is as follows:

early Validation Step
precision @ 1 = 0.2440 val loss = 1.39
Starting epoch 0
step 50, loss = 1.38, training_acc = 0.3250
...
step 1000, loss = 0.58, training_acc = 0.6725
...
step 3550, loss = 0.45, training_acc = 0.8063
Validation Step
precision @ 1 = 0.2473 val loss = 1.39

As shown, training accuracy and loss improve a lot after one epoch, but the validation loss doesn't change at all. I have tested this at least 10 times and the result is always the same. I would understand if the validation loss were getting worse due to overfitting, but here it simply doesn't change.

To rule out any problems with the validation data, I also show the results of training with the AlexNet implementation in slim. Training with the AlexNet model produces the following output:

early Validation Step
precision @ 1 = 0.2448 val loss = 1.39
Starting epoch 0
step 50, loss = 1.39, training_acc = 0.2587
...
step 350, loss = 1.38, training_acc = 0.2919
...
step 850, loss = 1.28, training_acc = 0.3898
Validation Step
precision @ 1 = 0.4069 val loss = 1.25

With the AlexNet model, accuracy and loss improve correctly on both the training and validation data, and they keep improving in subsequent epochs.

I don't understand what could be causing the problem, or why it shows up with the Inception/ResNet models but not with AlexNet.

Does anyone have ideas?

2 Answers:

Answer 0 (score: 0)

It looks like you are computing the validation loss from the raw logits; using the predictions instead might help.

val_logits, _ = inception_v1(val_images, ..., is_training=False, reuse=True)
val_logits = tf.nn.softmax(val_logits)

Answer 1 (score: 0)

After searching the forums, reading various threads, and experimenting, I found the source of the problem.

The problem was a train_op that I had essentially recycled from another example: it worked fine with the AlexNet model, but it failed on the other models because it did not run the batch-normalization updates.

To fix this, I had to use either

optimizer = tf.train.GradientDescentOptimizer(0.005)
train_op = slim.learning.create_train_op(total_loss, optimizer)

or

train_op = tf.contrib.layers.optimize_loss(total_loss, global_step, .005, 'SGD')

This seems to take care of the batch-norm updates that have to run during training.
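For reference, slim.learning.create_train_op attaches the ops collected in tf.GraphKeys.UPDATE_OPS (where slim.batch_norm registers its moving-average updates) as a dependency of the training step. A hand-rolled train_op can do the same thing explicitly; this is a minimal sketch of that pattern, not the code used here:

optimizer = tf.train.GradientDescentOptimizer(0.005)

# slim.batch_norm puts its moving-mean/variance update ops in this collection;
# they have to run with every training step, or the averages used at
# is_training=False stay frozen at their initial values.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss)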

With short training runs the problem can still show up, because the moving-average updates are slow.

The default slim arg_scope sets the decay to 0.9997, which is stable but obviously needs many steps to converge. Using the same arg_scope with the decay set to 0.99 or 0.9 did help in this short training scenario.
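One way to do that is to re-create the arg_scope with the decay exposed as a parameter. The following is a simplified sketch only (the real slim inception_v1_arg_scope also sets initializers, activation functions, and variable collections), and the function name is made up for this example:

def inception_v1_arg_scope_with_decay(weight_decay=0.00004,
                                      batch_norm_decay=0.9,
                                      batch_norm_epsilon=0.001):
    # Same structure as the slim arg_scope, but with a configurable batch-norm decay
    batch_norm_params = {
        'decay': batch_norm_decay,          # slim's default is 0.9997
        'epsilon': batch_norm_epsilon,
        'updates_collections': tf.GraphKeys.UPDATE_OPS,
    }
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        weights_regularizer=slim.l2_regularizer(weight_decay)):
        with slim.arg_scope([slim.conv2d],
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params) as sc:
            return sc

# Used like the original arg_scope:
# with slim.arg_scope(inception_v1_arg_scope_with_decay(batch_norm_decay=0.9)):
#     logits, _ = inception_v1(images, ..., is_training=True)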