如何在PyTorch中为非图像数据创建迷你批处理?

时间:2019-04-16 19:12:10

标签: pytorch

我想加载我的训练和测试数据

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, ), (0.5, ))])

trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)


testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

我看到图像数据的实现有没有办法以类似的方式加载非图像数据?

1 个答案:

答案 0 :(得分:0)

我们可以通过以下步骤为此使用torch.utils.data模块

  1. 通过继承torch.utils.data.Dataset

  2. 创建数据集类以加载自定义数据
  3. 通过将数据传递到自定义Dataset类的实例来创建数据集对象

  4. 使用torch.utils.data.DataLoader加载数据集并获取批次

假设您已经从目录中加载了数据,并且在训练和测试numpy数组中,您可以继承torch.utils.data.Dataset类来创建数据集对象

class MyDataset(Dataset):
    def __init__(self, x, y):
        super(MyDataset, self).__init__()
        assert x.shape[0] == y.shape[0] # assuming shape[0] = dataset size
        self.x = x
        self.y = y


    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

然后,创建您的数据集对象

traindata = MyDataset(train_x, train_y)

最后,使用DataLoader创建迷你批次

trainloader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)