如何改组数据集的标签?

时间:2019-10-23 16:14:46

标签: pytorch mnist

我已经使用以下命令下载了MNIST数据集:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

我现在需要在此数据集(MNIST)上进行一些实验,但改组训练集的标签。如何随机洗牌/重新分配?我尝试了以下方法:

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            target_transform=lambda y: torch.randint(0, 10, (1,)).item(),
                            download=True)

但是我注意到在训练过程中,lambda函数之后使标签混洗的原因是:他们在每个时代都在变化。这样,我将无法达到100%的训练精度。如何以完全随机的方式随机播放这些标签,以确保这些标签在训练过程中不会改变?

谢谢!

1 个答案:

答案 0 :(得分:1)

如果您的目标是创建标签的随机映射,则在定义目标转换以保持转换恒定之前,需要先定义映射。像下面这样的东西应该可以解决问题

import random
label_mapping = list(range(10))
random.shuffle(label_mapping)
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            target_transform=lambda y: label_mapping[y],
                            download=True)

为了在每个时期重新进行洗牌,您需要重新定义每个时期的标签映射,训练数据集和数据加载器。

更新要生成一个独立于真实标签但与给定索引一致的随机标签,则可能需要做一些非常仔细的播种或重新实现数据集类的某些功能。

例如,后一种情况可能看起来像这样

import random
class RandomMNIST(dsets.MNIST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.targets = [random.randint(0, 9) for _ in range(len(self.data))]

train_dataset = RandomMNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

或等效地

import random
train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)
train_dataset.targets = [random.randint(0, 9) for _ in range(len(train_dataset))]