如何在Pytorch中创建类似MNIST的数据集?

时间:2018-12-07 07:42:29

标签: dataset pytorch

我看过MNIST数据集的Pytorch源代码,但它似乎直接从二进制文件读取numpy数组。 我怎样才能像这样创建train_data和train_labels?我已经准备了带有标签的图片和txt。

我已经学会了如何读取图像并标记标签以及如何写get_item和len,真正令我感到困惑的是如何制作 train_data train_labels ,即Torch.Tensor。我试图将它们排列成python列表并转换为torch.Tensor,但失败了:

for index in range(0,len(self.files)):
  fn, label = self.files[index]
  img = self.loader(fn)
  if self.transform is not None:
    img = self.transform(img)
  train_data.append(img)
self.train_data = torch.tensor(train_data)

ValueError:只能将一个元素张量转换为Python标量

1 个答案:

答案 0 :(得分:0)

有两种方法。首先,手册。 Torchvision.datasets声明以下内容:

  

数据集是torch.utils.data.Dataset的子类,即,它们实现了__getitem____len__方法。因此,它们都可以传递给torch.utils.data.DataLoader,后者可以使用torch.multiprocessing工作程序并行加载多个样本。

因此,您只需实现自己的类即可扫描所有图像和标签,保留它们的路径列表(这样就不必将其保留在RAM中)并具有__getitem__方法给定的索引i读取第i个文件,其标签并返回它们。这个最小的接口足以与dataloader中的并行torch.utils.data一起使用。

第二,如果可以将数据目录重新排列为任一结构,则可以使用DatasetFolderImageFolder预先构建的加载程序。这样可以节省一些编码,并自动为来自torchvision.transforms的数据扩展例程提供支持。