我看过MNIST数据集的Pytorch源代码,但它似乎直接从二进制文件读取numpy数组。 我怎样才能像这样创建train_data和train_labels?我已经准备了带有标签的图片和txt。
我已经学会了如何读取图像并标记标签以及如何写get_item和len,真正令我感到困惑的是如何制作 train_data 和 train_labels ,即Torch.Tensor。我试图将它们排列成python列表并转换为torch.Tensor,但失败了:
for index in range(0,len(self.files)):
fn, label = self.files[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
train_data.append(img)
self.train_data = torch.tensor(train_data)
ValueError:只能将一个元素张量转换为Python标量
答案 0 :(得分:0)
有两种方法。首先,手册。 Torchvision.datasets声明以下内容:
数据集是torch.utils.data.Dataset的子类,即,它们实现了
__getitem__
和__len__
方法。因此,它们都可以传递给torch.utils.data.DataLoader,后者可以使用torch.multiprocessing工作程序并行加载多个样本。
因此,您只需实现自己的类即可扫描所有图像和标签,保留它们的路径列表(这样就不必将其保留在RAM中)并具有__getitem__
方法给定的索引i
读取第i个文件,其标签并返回它们。这个最小的接口足以与dataloader中的并行torch.utils.data一起使用。
第二,如果可以将数据目录重新排列为任一结构,则可以使用DatasetFolder和ImageFolder预先构建的加载程序。这样可以节省一些编码,并自动为来自torchvision.transforms的数据扩展例程提供支持。