我想为Torchvision MNIST数据集实现这种情况,并使用DataLoader
加载数据:
batch A (unaugmented images): 5, 0, 4, ...
batch B (augmented images): 5*, 5+, 5-, 0*, 0+, 0-, 4*, 4+, 4-, ...
...其中,对于A的每个图像,批次B中都有3个增强。len(B)= 3 * len(A)对应。这些批次应在一次迭代中使用,以将批次A的原始图像与批次B中增强的图像进行比较,以建立损失。
class MyMNIST(Dataset):
def __init__(self, mnist_dir, train, augmented, transform=None, repeat=1):
self.mnist_dir = mnist_dir
self.train = train
self.augmented = augmented
self.repeat = repeat
self.transform = transform
self.dataset = None
if augmented and train:
self.dataset = datasets.MNIST(self.mnist_dir, train=train, download=True, transform=transform)
self.dataset.data = torch.repeat_interleave(self.dataset.data, repeats=self.repeat, dim=0)
self.dataset.targets = torch.repeat_interleave(self.dataset.targets, repeats=self.repeat, dim=0)
elif augmented and not train:
raise Exception("Test set should not be augmented.")
else:
self.dataset = datasets.MNIST(MNIST_DIR, train=train, download=True, transform=transform)
使用此类,我想初始化两个不同的数据加载器:
orig_train = MyMNIST(MNIST_DIR, train=True, augmented=False, transform=orig_transforms)
orig_train_loader = torch.utils.data.DataLoader(orig_train.dataset, batch_size=100, shuffle=True)
aug_train = MyMNIST(MNIST_DIR, train=True, augmented=True, transform=aug_transforms, repeat=3)
aug_train_loader = torch.utils.data.DataLoader(aug_train.dataset, batch_size=300, shuffle=True)
我现在的问题是,我还需要在每次迭代中都进行洗牌,同时保持A和B之间的顺序相关。对于上述代码,这是不可能的,因为两个DataLoader
都会产生不同的顺序。因此,我尝试使用单个DataLoader
并手动复制重复的批次:
for batch_no, (images, labels) in enumerate(orig_train_loader):
repeat_images = torch.repeat_interleave(images, 3, dim=0)
这样,我正确地获得了批次B(repeat_images
)的订单,但是现在我缺少了需要在批次/迭代中应用的转换。这似乎不是Pytorch的范式,至少我没有找到做到这一点的方法。
如果有人可以帮助我,我会很高兴-我对Pytorch(以及Stackoverflow)还很陌生,所以也欢迎批评我的整个方法,可能出现的性能问题等。
非常感谢!