Using fastai image augmentation lowers my accuracy

Date: 2020-05-30 07:22:27

Tags: python computer-vision pytorch classification fast-ai

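I am doing classification on the Food-101 dataset and am trying to reach the SOTA score of >= 90% with ResNet50. My accuracy started at 75%. Adding label smoothing raised it to 81%, but now that I also use fastai's get_transforms it has dropped to 79%. I also noticed that, even though I am using image augmentation, the size of the data does not increase when I run the epochs.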
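On the data size: as far as I know, fastai (like torchvision) applies augmentations on the fly, each time an item is loaded, rather than adding extra copies of the images to the dataset. The number of items per epoch therefore stays the same, and each epoch sees a freshly randomised version of every image. A minimal stdlib sketch of that idea follows; the class, the run_epoch helper and the toy list-of-rows "images" are hypothetical and are not fastai's API.

```python
import random


class OnTheFlyAugmented:
    """Hypothetical toy dataset: augmentation runs inside __getitem__,
    so the dataset length never changes."""

    def __init__(self, images, p_flip=0.5, seed=None):
        self.images = images              # each toy "image" is a list of pixel rows
        self.p_flip = p_flip              # probability of a horizontal flip
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.images)           # no extra samples are created

    def __getitem__(self, i):
        img = self.images[i]
        if self.rng.random() < self.p_flip:
            img = [row[::-1] for row in img]  # random horizontal flip, decided per access
        return img


def run_epoch(ds):
    """Visit every index once, as a data loader does in one epoch."""
    return [ds[i] for i in range(len(ds))]


ds = OnTheFlyAugmented([[[1, 2, 3]], [[4, 5, 6]]], seed=0)
for epoch in range(3):
    samples = run_epoch(ds)   # always len(ds) samples; which ones are flipped varies
```

The design point is that randomness lives in the loading step, so "more data" shows up as more variety across epochs, not as a larger item count.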

Currently, my augmentations are in the tfms variable:

from fastai.vision import *  # fastai v1 API (ImageList, get_transforms, cnn_learner, ...)

np.random.seed(42)

path = '/content/food-101/images/train'
file_parse = r'/([^/]+)_\d+\.(png|jpg|jpeg)$'
tfms = get_transforms(do_flip=True,flip_vert=True, max_rotate=10.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75)

data = ImageList.from_folder(path).split_by_rand_pct(valid_pct=0.2).label_from_re(pat=file_parse).transform(tfms, size=224).databunch()

top_1 = partial(top_k_accuracy, k=1)
learn = cnn_learner(data, models.resnet50, metrics=[accuracy, top_1], loss_func = LabelSmoothingCrossEntropy(), callback_fns=ShowGraph)

learn.lr_find()
learn.recorder.plot(suggestion=True)
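The file_parse pattern takes each label from the file name. A quick way to check what it extracts is to run it with Python's re module directly. The sketch assumes a hypothetical file name of the form <label>_<number>.jpg, which is what the pattern expects; for comparison, the stock Food-101 archive stores images as images/<label>/<id>.jpg, which this pattern does not match. label_from_path is my own helper, meant only to mirror roughly what label_from_re does for each file (search the path, return the first group).

```python
import re

# Same pattern as file_parse above.
FILE_PARSE = r'/([^/]+)_\d+\.(png|jpg|jpeg)$'


def label_from_path(path, pat=FILE_PARSE):
    """Hypothetical helper: return capture group 1, or fail loudly if nothing matches."""
    m = re.search(pat, path)
    if m is None:
        raise ValueError(f"no label found in {path!r}")
    return m.group(1)


# Hypothetical renamed file: the greedy [^/]+ backtracks to the last "_<digits>".
label = label_from_path('/content/food-101/images/train/apple_pie_1005649.jpg')
```

Here label is 'apple_pie'; a path such as '/content/food-101/images/train/apple_pie/1005649.jpg' raises ValueError because the file name has no "_<digits>" part.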


When I applied get_transforms() to the model, the learning rate suggested by lr_find also rose slightly, from 1.02e-06 to 1.023e-06.
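My understanding of recorder.plot(suggestion=True) in fastai v1 is that it reports the learning rate at which the (smoothed) recorded loss falls most steeply, i.e. the minimum of a numerical gradient of the loss curve taken over step index; since the LR finder raises the rate exponentially, that is effectively the slope against log(lr). Because lr_find runs on random mini-batches, small shifts such as 1.02e-06 vs 1.023e-06 between runs are to be expected. The sketch below is a plain-Python reconstruction under that assumption, not fastai's code, and the loss values in the usage line are made up.

```python
def numerical_gradient(ys):
    """Like numpy.gradient with unit spacing: one-sided at the ends, central inside."""
    if len(ys) < 2:
        raise ValueError("need at least two points")
    grads = [ys[1] - ys[0]]
    for i in range(1, len(ys) - 1):
        grads.append((ys[i + 1] - ys[i - 1]) / 2)
    grads.append(ys[-1] - ys[-2])
    return grads


def suggest_lr(lrs, losses):
    """Return the learning rate where the loss curve falls most steeply."""
    grads = numerical_gradient(losses)
    steepest = min(range(len(grads)), key=lambda i: grads[i])
    return lrs[steepest]


# Made-up LR-finder trace: loss drops fastest around index 2.
lrs = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [3.0, 2.9, 2.5, 1.5, 1.4, 3.0]
suggested = suggest_lr(lrs, losses)
```

With this made-up trace the suggestion is 1e-4, the point just before the steepest drop levels off.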

I then ran 5 more epochs. The accuracy is low because of label smoothing, which is not the problem:

learn.fit_one_cycle(5, max_lr=slice(1.023e-06/5, 1.023e-06/15))
learn.save('stage-2')
epoch   train_loss  valid_loss  accuracy    top_k_accuracy  time
0       6.549849    5.393574    0.009835    0.009835        18:39
1       6.520850    5.373156    0.009835    0.009835        17:57
2       6.468420    5.362789    0.009901    0.009901        17:49
3       6.461528    5.363743    0.009769    0.009769        17:51
4       6.473554    5.364657    0.009901    0.009901        18:05
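Two numbers in this run can be checked directly. First, with 101 classes, random guessing gives an accuracy of 1/101 ≈ 0.0099, which is where the accuracy column above sits. Second, fastai v1 turns slice(start, stop) into one learning rate per layer group; my reading of Learner.lr_range and even_mults is that the rates are spaced geometrically from start (earliest group) to stop (the head), and that before unfreeze() only the last group (plus, I believe, BatchNorm layers) is trained, so the head's rate is the one that matters in this stage. The sketch re-implements that spacing in plain Python, assuming 3 layer groups (which I believe is what cnn_learner uses for ResNets); it is a reconstruction, not fastai's code.

```python
def even_mults(start, stop, n):
    """n values spaced geometrically from start to stop (inclusive)."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


def lrs_for_groups(lr, n_groups=3):
    """Per-group learning rates; index 0 is the earliest backbone group, -1 the head.
    Assumes fastai v1 semantics for slice(start, stop) and slice(stop)."""
    if not isinstance(lr, slice):
        return [lr] * n_groups
    if lr.start:
        return even_mults(lr.start, lr.stop, n_groups)
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]


stage2 = lrs_for_groups(slice(1.023e-06 / 5, 1.023e-06 / 15))  # slice used for stage 2
stage3 = lrs_for_groups(slice(1e-03 / 5, 1e-03 / 15))          # slice used for stage 3
```

With the slices used in the question, start is larger than stop, so the head receives the smallest rate: about 6.8e-08 in stage 2 and about 6.7e-05 in stage 3, while the earliest group receives about 2.0e-07 and 2.0e-04. Writing the smaller value first, for example slice(1e-05, 1e-03) as an illustration, gives the usual increasing order from backbone to head.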


Next, I unfroze the model and looked for a new learning rate. It jumped to 1.03e-06:

learn.unfreeze()
learn.lr_find()
learn.recorder.plot(suggestion=True)


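I ran another 5 epochs with the new LR and saved the model. In this round, label smoothing does not seem to have the same effect as before, and the model shows better accuracy than in the previous 5 epochs: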

learn.fit_one_cycle(5, max_lr=slice(1e-03/5, 1e-03/15))
learn.save('stage-3')
epoch   train_loss  valid_loss  accuracy    top_k_accuracy  time
0       3.476312    2.645357    0.491683    0.491683        18:11
1       2.781384    2.276255    0.599670    0.599670        18:22
2       2.356208    1.974409    0.677426    0.677426        18:31
3       2.068619    1.836324    0.732409    0.732409        18:26
4       1.943876    1.789893    0.742310    0.742310        18:20
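For reading these loss columns: my understanding of fastai v1's LabelSmoothingCrossEntropy (default eps=0.1) is that the per-sample loss mixes the usual negative log-likelihood of the target with the average negative log-probability over all c classes, loss = (1 - eps) * (-log p_target) + (eps / c) * sum_j(-log p_j). The stdlib sketch below implements that formula; the function names are mine. One consequence: a model that predicts all 101 classes with equal probability scores ln(101) ≈ 4.615 under this loss, so values above that (such as the stage-2 valid_loss of about 5.36) are worse than a uniform guess.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def label_smoothing_ce(logits, target, eps=0.1):
    """(1 - eps) * NLL of the target + (eps / c) * sum of -log p over all c classes."""
    c = len(logits)
    log_p = log_softmax(logits)
    nll = -log_p[target]
    smooth = -sum(log_p) / c
    return (1 - eps) * nll + eps * smooth


uniform_101 = label_smoothing_ce([0.0] * 101, target=0)  # ln(101) up to rounding
```

Setting eps=0 recovers ordinary cross-entropy; with eps > 0 even a very confident correct prediction keeps a clearly positive loss, which is why smoothed losses stay higher than plain cross-entropy would.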


When I check where the model gets confused, using the following code:

interp = ClassificationInterpretation.from_learner(learn)  # interp must be created first
interp.plot_confusion_matrix()
interp.most_confused(min_val=5)
interp.plot_multi_top_losses()
interp.plot_confusion_matrix(figsize=(20, 20), dpi=200)

I get this output about the validation set:

15081 misclassified samples over 15150 samples in the validation set.
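This count can be checked against the accuracy column with simple arithmetic. Correct predictions sit on the diagonal of a confusion matrix, so accuracy = trace / total and misclassified = total - trace. 15081 misclassified out of 15150 would leave 69 correct, an accuracy of about 0.0046, whereas the last training epoch reported 0.742 on what is presumably the same validation set, so the two figures do not seem to describe the same predictions. A small sketch of both calculations; the helper names and the 2x2 matrix in the test are hypothetical.

```python
def accuracy_from_confusion(cm):
    """cm[i][j] counts samples of true class i predicted as class j."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / total


def misclassified_from_confusion(cm):
    """Everything off the diagonal is a wrong prediction."""
    total = sum(sum(row) for row in cm)
    return total - sum(cm[i][i] for i in range(len(cm)))


def accuracy_from_counts(n_misclassified, n_total):
    return (n_total - n_misclassified) / n_total


implied = accuracy_from_counts(15081, 15150)  # figures printed above
```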

Finally, when I evaluate the model on the test set, this is what I get:

path = '/content/food-101/images'
data_test = ImageList.from_folder(path).split_by_folder(train='train', valid='test').label_from_re(file_parse).transform(size=224).databunch()

learn.load('stage-3')
learn.validate(data_test.valid_dl)

[1.602944, tensor(0.7987), tensor(0.7987)]
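As I understand it, learn.validate returns the loss followed by each metric in order, here accuracy and then top_1. Because top_1 is top_k_accuracy with k=1, it counts a hit only when the target is the single highest-scoring class, which is exactly what accuracy measures; the last two values (and the two metric columns in the training logs) are therefore identical by construction. A stdlib sketch with my own function names, not fastai's implementations:

```python
def plain_accuracy(scores, targets):
    """Fraction of samples whose highest-scoring class equals the target."""
    hits = sum(
        max(range(len(row)), key=lambda j: row[j]) == t
        for row, t in zip(scores, targets)
    )
    return hits / len(targets)


def top_k_acc(scores, targets, k):
    """Fraction of samples whose target is among the k highest-scoring classes."""
    hits = 0
    for row, t in zip(scores, targets):
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        hits += t in top_k
    return hits / len(targets)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 2]
```

With k=1 both functions give 2/3 on this toy data; with k=2 the second sample also counts, giving 1.0. A k larger than 1 (for example k=5) would make the extra metric informative.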
