如何随机调整Pytorch数据集的标签?

时间:2019-07-27 12:00:43

标签: machine-learning computer-vision dataset pytorch

我是Pytorch的新手,并且在某些技术方面遇到麻烦。我已经使用以下命令下载了MNIST数据集:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

我现在需要对该数据集进行一些实验,但要使用随机标签。如何随机洗牌/重新分配?

我正在尝试手动执行此操作,但是它告诉我“'tuple'对象不支持项目分配”。那我该怎么办呢?

第二个问题:如何从数据集中删除训练点?当我尝试这样做时,它也会给我同样的错误。

谢谢!

1 个答案:

答案 0 :(得分:0)

如果只想改组目标,则可以使用target_transform参数。例如:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            target_transform=lambda y: torch.randint(0, 10, (1,)).item(),
                            download=True)

如果您想对数据集进行更精细的调整, 您可以完全包裹mnist

class MyTwistedMNIST(torch.utils.data.Dataset):
  def __init__(self, my_args):
    super(MyTwistedMNIST, self).__init__()
    self.orig_mnist = dset.MNIST(...)  

  def __getitem__(self, index):
    x, y = self.orig_mnist[index]  # get the original item
    my_x = # change input digit image x ?
    my_y = # change the original label y ?
    return my_x, my_y

  def __len__(self):
    return self.orig_mnist.__len__()

如果要完全丢弃原始mnist的某些元素,则可以通过包装原始mnist来返回MyTwistedMNIST小于len的{​​{1}}类,以反映出您要处理的实际mnist示例。此外,您将需要将新的self.orig_mnist.__len__()个示例映射到原始的mnist索引。