CNN无法正确学习

时间:2020-04-05 22:06:52

标签: deep-learning regression pytorch conv-neural-network multiclass-classification

我有一个包含500张植物图像的小型数据集,我必须为[1,10]范围内的单个图像预测一个数字。数字之间存在顺序关系(10> 9> ...> 1)。这个问题类似于基于单张照片的年龄估计。

我尝试使用Resnet18,Resnet34和VGG16进行回归。他们都没有给出很好的结果。

有趣的一点是,当我绘制一些图像的热图时,它表明模型选择了错误的点来预测答案。就像,如果我假设要根据面部照片预测年龄,那么cnn赋予背景的价值要大于实际面孔的价值。

我也尝试了其他方法,例如分类和学习排名,但是当我进行热图绘制时,会发生相同的事情。在这些方法中,我得到的最佳准确性是使用分类的30%和使用学习排名的35%。

我使用带有预先训练的Fastai实现的回归和分类方法。在学习排名方法时,我使用了https://github.com/Raschka-research-group/coral-cnn。我做了一些改动,以便能够使用预训练的模型。

另一个重要的一点是数据集是不平衡的。数据集的80%对应于6至10类。

是否有人有任何改进建议或我可以尝试的其他方法?

编辑: 我的数据扩充看起来像这样:

transforms.Compose([
                  transforms.Resize(256), transforms.CenterCrop(224),
                  transforms.RandomHorizontalFlip(p=0.5),
                  transforms.ColorJitter(brightness=0.15), 
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.299, 0.224, 0.225])
])

1 个答案:

答案 0 :(得分:0)

您可以尝试扩充数据集以获得更多数据(例如随机裁剪,旋转等),并确保对数据进行规范化。对于班级不平衡问题,您可以尝试使用PyTorch的{​​{1}}:

WeightedRandomSampler

您应该可以轻松地将其应用于10个类的案例,希望这可以解决您的问题!