使用pytorch Dataloader加载并正确显示图像数据集

时间:2019-11-14 16:48:43

标签: python image-processing pytorch

我正在尝试加载用于训练神经网络的自定义数据集,但是在加载它们之前,我想验证它们是否已正确加载。到目前为止,看起来它们没有正确加载,但我无法弄清楚是什么赋予了图像它们所获得的格式。

这是我用来加载图像然后显示它们的代码。

f, axarr = plt.subplots(2,2, figsize=(20,20))

def load_dataset():
    data_path = 'processedData/HE/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = DataLoader(
        train_dataset, batch_size=64
    )
    return train_loader
x_train = load_dataset()

datathing = next(iter(x_train))

for i, ax in enumerate(axarr.flat):
    ax.imshow(datathing[0][i].view(128,128,3))
    ax.axis('off')
plt.show()

对图像运行此命令时,输出looks like this.

假设看起来like these images

我一直在尝试使用不同的图像数据集,但是所有集合都返回相同的格式,所以我的问题是:

  • 如何加载图像,然后使用pytorch的数据加载器以真实格式显示图像?

1 个答案:

答案 0 :(得分:0)

katyperry弄乱了图像。

您可以阅读转换.ToTensor(...)的文档:

  

[...]

     

将范围为[0,255]的PIL图像或numpy.ndarray (H xW x C)转换为形状为(C xH x W)的手电筒。 在[0.0,1.0]

范围内      

[...]

即,通道尺寸从最后一个尺寸移动到第一个尺寸。您可以在source code中看到它:

.view(128, 128, 3)

因此,您不能简单地致电# ... img = torch.from_numpy(pic.transpose((2, 0, 1))) # ... ;您必须将其转回。在PyTorch中,您可以使用.view(...)函数。像这样:

.permute(...)