条件GAN-相同地洗牌/拆分两个数据集

时间:2019-04-22 11:54:17

标签: dataset pytorch dcgan

我正在尝试使用DCGAN训练对某些图像进行着色。这样做时,我将GAN设置为图像的灰度版本。然后,我想先用一批真实图像训练我的GAN /鉴别器,然后再用一批伪图像训练。每隔一段时间,我想比较图像的彩色,灰度和地面真实版本。因此,我需要用相同的方式分割成批的真实/灰色图像。我用pytorch。查看我包含的代码,它们应该具有相同的批次。但是,他们没有。

我尝试了没有worker_init_fn。我也尝试了不同的随机函数调用,并将它们传递给worker_init_fn无济于事。

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))

dataloader_gray = torch.utils.data.DataLoader(dataset_gray, batch_size=batch_size,
                                          shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))

for i, (data, data_gray) in enumerate(zip(dataloader, dataloader_gray)):
    doStuff()

1 个答案:

答案 0 :(得分:0)

正如Haran Rajkumar在评论中指出的那样,更好的解决方案将包括预先连接两个数据集,然后再应用torch.utils.DataLoader(前提是两个torch.utils.Dataset对象都包含与图像完全相同顺序的图像)。开始)。

请注意,不必创建单独的类即可执行此操作,torch.utils.data.ConcatDataset开箱即用地提供了此功能。

不确定您的确切代码,但这应该足够(或至少足以使您朝正确的方向):

import torch

dataloader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset(dataset, dataset_gray),
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)

for i, (data, data_gray) in enumerate(dataloader):
    doStuff()

如您所见,它更具可读性,更易于使用。