pytorch数据加载器“ nan”返回值

时间:2018-10-27 21:41:23

标签: python machine-learning pytorch torch vgg-net

我正在基于VGG的GAN进行工作,它基于一个名为CartoonGAN:https://github.com/znxlwm/pytorch-CartoonGAN的现有项目。总体而言,该模型不是针对损失函数进行训练并返回“ nan”。我很确定这是因为数据加载存在问题。原始项目的作者定义了一个“ data_load”函数来包装pytorch的“ DataLoader”函数。发表在下面:

def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=True):
    dset = datasets.ImageFolder(path, transform)
    print("DSET IS "+str(dset))
    ind = dset.class_to_idx[subfolder]
    print("IND IS " + str(ind))

    n = 0
    for i in range(dset.__len__()):
        print("for loop on interation ", str(i))
        print("dset.img is "+str(dset.imgs))
        #print("Current item: ", str(dset.__getitem__(i)))
        if ind != dset.imgs[n][1]:
            del dset.imgs[n]
            n -= 1

        n += 1

    return torch.utils.data.DataLoader(dset, 
           batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

我们认为返回的data_loader对象为空或格式错误。我们认为所有基于包装器的参数都是正确的,并且该错误在函数内发生。

我们无法确定失败的原因,并且我们想知道是否有人对此有所了解。

谢谢!

0 个答案:

没有答案