我正在基于VGG的GAN进行工作,它基于一个名为CartoonGAN:https://github.com/znxlwm/pytorch-CartoonGAN的现有项目。总体而言,该模型不是针对损失函数进行训练并返回“ nan”。我很确定这是因为数据加载存在问题。原始项目的作者定义了一个“ data_load”函数来包装pytorch的“ DataLoader”函数。发表在下面:
def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=True):
dset = datasets.ImageFolder(path, transform)
print("DSET IS "+str(dset))
ind = dset.class_to_idx[subfolder]
print("IND IS " + str(ind))
n = 0
for i in range(dset.__len__()):
print("for loop on interation ", str(i))
print("dset.img is "+str(dset.imgs))
#print("Current item: ", str(dset.__getitem__(i)))
if ind != dset.imgs[n][1]:
del dset.imgs[n]
n -= 1
n += 1
return torch.utils.data.DataLoader(dset,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
我们认为返回的data_loader对象为空或格式错误。我们认为所有基于包装器的参数都是正确的,并且该错误在函数内发生。
我们无法确定失败的原因,并且我们想知道是否有人对此有所了解。
谢谢!