如何在pytorch MNIST数据集中选择特定标签

时间:2019-09-12 20:05:57

标签: python pytorch

我正在尝试仅使用PyTorch Mnist数据集中的特定数字创建数据加载器

我已经尝试创建自己的采样器,但是它不起作用,并且我不确定是否正确使用了蒙版。

class YourSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, mask):

        self.mask = mask


    def __iter__(self):

        return (self.indices[i] for i in torch.nonzero(self.mask))


    def __len__(self):

        return len(self.mask)


mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)   

mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]

mask = torch.tensor(mask)   

sampler = YourSampler(mask)

trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler = sampler, shuffle=False, num_workers=2)

到目前为止,我有很多不同类型的错误。对于此实现,它是“停止迭代”。 我觉得这很容易/愚蠢,但是我找不到简单的方法来做。 谢谢您的帮助!

3 个答案:

答案 0 :(得分:1)

我能想到的最简单的选择是就地减少数据集:

indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]

答案 1 :(得分:0)

迭代器用尽时会引发

StopIteration。您确定口罩工作正常吗?似乎您传递了布尔值列表,但是torch.nonzero会期望使用浮点数或整数。

您应该写:

mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]

您还需要将数据集传递给采样器,例如:

sampler = YourSampler(dataset, mask=mask)

具有此类定义

class YourSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, dataset, mask):

        self.mask = mask
        self.dataset = dataset
...

有关更多详细信息,请参阅pytorch文档(显示源代码)以了解他们如何实现更多高级采样器:https://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html#SequentialSampler

答案 2 :(得分:0)

感谢您的帮助。 一段时间后,我想出了一个解决方案(但可能不是最好的解决方案):

#define _CRT_SECURE_NO_WARNINGS 1