如何使用PyTorch将数据矩阵作为标签分配给数据集中的每个输入图像?

时间:2019-02-14 20:09:05

标签: deep-learning dataset pytorch

我想在PyTorch中训练卷积神经网络(CNN)以预测与输入图像有关的频谱数据。与其给每个图像(狗,猫,汽车,飞机等)分配一个标签,不如给每个图像分配标签矩阵(每个频率一个标签)。在PyTorch中,如何为数据集中的每个输入图像分配数据矩阵作为标签?我一直在尝试使用ImageFolder进行此操作。谢谢!

1 个答案:

答案 0 :(得分:0)

ImageFolder数据集期望一种仅处理一个类标签的特定结构,因此可能不适合您的需求。

如果您查看DatasetFolder Class,它是ImageFolder的父类,请注意__getitem__方法,该方法获取一个索引并返回样本和目标。您可以做的一件事是更改数据加载器的samples属性,使其返回您自己的路径并为每个索引标记元组。

或者,您可以创建自己的数据集类,该数据集类继承自torch.utils.data.Dataset,并具有__getitem____len__方法。