我正在尝试使用DCGAN训练对某些图像进行着色。这样做时,我将GAN设置为图像的灰度版本。然后,我想先用一批真实图像训练我的GAN /鉴别器,然后再用一批伪图像训练。每隔一段时间,我想比较图像的彩色,灰度和地面真实版本。因此,我需要用相同的方式分割成批的真实/灰色图像。我用pytorch。查看我包含的代码,它们应该具有相同的批次。但是,他们没有。
我尝试了没有worker_init_fn。我也尝试了不同的随机函数调用,并将它们传递给worker_init_fn无济于事。
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))
dataloader_gray = torch.utils.data.DataLoader(dataset_gray, batch_size=batch_size,
shuffle=True, num_workers=workers, worker_init_fn = random.seed(seed))
for i, (data, data_gray) in enumerate(zip(dataloader, dataloader_gray)):
doStuff()
答案 0 :(得分:0)
正如Haran Rajkumar在评论中指出的那样,更好的解决方案将包括预先连接两个数据集,然后再应用torch.utils.DataLoader
(前提是两个torch.utils.Dataset
对象都包含与图像完全相同顺序的图像)。开始)。
请注意,不必创建单独的类即可执行此操作,torch.utils.data.ConcatDataset
开箱即用地提供了此功能。
不确定您的确切代码,但这应该足够(或至少足以使您朝正确的方向):
import torch
dataloader = torch.utils.data.DataLoader(
torch.utils.data.ConcatDataset(dataset, dataset_gray),
batch_size=batch_size,
shuffle=True,
num_workers=workers
)
for i, (data, data_gray) in enumerate(dataloader):
doStuff()
如您所见,它更具可读性,更易于使用。