我是Pytorch的新手,并且在某些技术方面遇到麻烦。我已经使用以下命令下载了MNIST数据集:
train_dataset = dsets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
我现在需要对该数据集进行一些实验,但要使用随机标签。如何随机洗牌/重新分配?
我正在尝试手动执行此操作,但是它告诉我“'tuple'对象不支持项目分配”。那我该怎么办呢?
第二个问题:如何从数据集中删除训练点?当我尝试这样做时,它也会给我同样的错误。
谢谢!
答案 0 :(得分:0)
如果只想改组目标,则可以使用target_transform
参数。例如:
train_dataset = dsets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
target_transform=lambda y: torch.randint(0, 10, (1,)).item(),
download=True)
如果您想对数据集进行更精细的调整,
您可以完全包裹mnist
class MyTwistedMNIST(torch.utils.data.Dataset):
def __init__(self, my_args):
super(MyTwistedMNIST, self).__init__()
self.orig_mnist = dset.MNIST(...)
def __getitem__(self, index):
x, y = self.orig_mnist[index] # get the original item
my_x = # change input digit image x ?
my_y = # change the original label y ?
return my_x, my_y
def __len__(self):
return self.orig_mnist.__len__()
如果要完全丢弃原始mnist的某些元素,则可以通过包装原始mnist来返回MyTwistedMNIST
小于len
的{{1}}类,以反映出您要处理的实际mnist示例。此外,您将需要将新的self.orig_mnist.__len__()
个示例映射到原始的mnist索引。