如何将Pytoch数据集克隆到另一个变量中?

时间:2020-02-10 20:23:27

标签: python-3.x pytorch sampling

我想创建Pytorch中提供的MNIST数据集的几个子集。每个子集应具有不同的类。我尝试的是以下内容:

def split_MNIST(mnist_set, digits):
    dset = mnist_set
    classes = []
    indices = dset.targets == digits[0]
    classes.append(dset.classes[digits[0]])
    if len(digits) > 1:
        for digit in digits[1:]:
            idx = dset.targets == digit
            indices = indices + idx
            classes.append(dset.classes[digit])
    dset.targets = dset.targets[indices]
    dset.data = dset.data[indices]
    dset.classes = classes
    return dset


train = datasets.MNIST("../data", train=True, download=True,
                        transform=transforms.Compose([transforms.ToTensor()]))

test =datasets.MNIST("../data", train=False, download=True,
                      transform=transforms.Compose([transforms.ToTensor()]))

tr = split_MNIST(train, [1,2,3])

trainset = torch.utils.data.DataLoader(tr, batch_size=16, shuffle=True)

这有效,但实际上没有更改新的数据集,而是更改了原始火车变量。有没有一种方法可以创建数据集的副本来保留原始副本?

1 个答案:

答案 0 :(得分:0)

只需将数据集实例化放在split_MNIST函数中即可。

def split_MNIST(path2data, train, download, transform, digits):
    dset = datasets.MNIST(path2data, train=train, download=download, transform=transform)
    classes = []
    indices = dset.targets == digits[0]
    classes.append(dset.classes[digits[0]])
    if len(digits) > 1:
        for digit in digits[1:]:
            idx = dset.targets == digit
            indices = indices + idx
            classes.append(dset.classes[digit])
    dset.targets = dset.targets[indices]
    dset.data = dset.data[indices]
    dset.classes = classes
    return dset


transforms = transforms.Compose([transforms.ToTensor()])
tr = split_MNIST('../data', train=True, download=True, transform=transforms, digits=[1,2,3])

trainset = torch.utils.data.DataLoader(tr, batch_size=16, shuffle=True)