CNN模型在达到50%的准确性后就过度拟合了数据

时间:2020-09-20 21:11:33

标签: python machine-learning deep-learning pytorch conv-neural-network

我正在尝试根据EEG连接组数据识别3种(类)精神状态。数据的形状为99x1x34x34x50x130(最初为图形数据,但现在表示为矩阵),分别表示[对象,通道,高度,宽度,频率,时间序列]。为了进行这项研究,只能输入1x34x34的连接套组数据图像。根据先前的研究,发现Alpha波段(8-1 hz)提供的信息最多,因此数据集的范围缩小到99x1x34x34x4x130。对诸如SVM之类的以往机器学习技术的测试集准确性达到了约75%的测试准确性。因此,目标是在给定相同数据(1x34x34)的情况下实现更高的准确性。由于我的数据非常有限,用于训练的1-66和用于测试的66-99(固定比率,并且具有1/3的类别分布),所以我考虑沿时间序列轴(第6轴)拆分数据,然后平均数据形状为1x34x34(例如1x34x34x4x10,其中10是时间序列的随机样本)。这给了我约1500个训练样本和33个测试样本(测试是固定的,班级分布是1/3)。

型号:

SimpleCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (drop1): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=9248, out_features=128, bias=True)
  (drop2): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=128, out_features=3, bias=True)
)
CrossEntropyLoss()
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 5e-06
    weight_decay: 0.0001
)

结果: 训练集可以通过足够的迭代来达到100%的精度,但是要付出测试集精度的代价。经过大约20至50个测试时期,模型开始过拟合训练集,并且测试集的准确性开始降低(与损失相同)。

enter image description here

我尝试过的方法: 我尝试过调整超参数:lr = .001-000001,权重衰减= 0.0001-0.00001。训练到1000个时期(少于100个时期的无用BC过度拟合)。我还尝试通过添加添加fc层以及在CNN层中形成8-64的不同数量的通道来增加/减少模型复杂性。我还尝试添加更多的CNN层,并且该模型在测试集上的平均精度大约为45%时差一些。我尝试每10个时段手动安排学习率,结果是一样的。体重减轻似乎对结果没有太大影响,将其从0.1-0.000001更改为

根据以前的测试,我有一个模型可以在测试和训练集上均达到60%的准确率。但是,当我尝试对其进行重新培训时,两套设备(培训和测试)的acc都立即降低到40左右,这没有任何意义。我曾尝试将学习率从0.01更改为0.00000001,并为此尝试了权重衰减。

从训练模型和图表开始,模型似乎不知道在最初的5-10个时期内在做什么,然后开始快速学习,两组都以〜50%-60%的速度递增。这就是模型开始过度拟合的地方,在那里,模型的acc在训练集上的acc增加到100%,测试集的acc下降到33%,这相当于猜测。

有什么提示吗?

编辑:

测试集的模型输出彼此非常接近。

0.33960407972335815, 0.311821848154068, 0.34857410192489624

每个图像的预测之间的整个测试集的平均标准偏差为(softmax):

0.017695341517654846

但是,训练集的平均std为.22,所以...

F1得分:

Micro Average: 0.6060606060606061
Macro Average: 0.5810185185185186
Weighted Average: 0.5810185185185186
Scores for each class: 0.6875 0.5 0.55555556

这是一个混淆矩阵: enter image description here

1 个答案:

答案 0 :(得分:3)

我有一些建议,我会尝试什么,也许您已经做到了:

  • 增加辍学的可能性,这可以减少过度拟合的情况,
  • 我没有看到或错过了它,但是如果您不这样做,请随机洗所有样本,
  • 没有太多数据,您是否考虑过使用其他NN生成得分最低的类的更多数据?我不确定是否是这种情况,但即使随机旋转,缩放图像也可以产生更多训练示例,
  • 您可以采用的另一种方法(如果尚未完成的话),通过另一种流行的CNN网络使用转移学习,并查看其工作情况,然后可以进行一些比较,无论您的体系结构是否有问题或缺乏例子:) 我知道这些只是建议,但也许,如果您没有尝试其中的一些建议,它们将使您更接近解决方案。 祝你好运!