我正在尝试加载用于训练神经网络的自定义数据集,但是在加载它们之前,我想验证它们是否已正确加载。到目前为止,看起来它们没有正确加载,但我无法弄清楚是什么赋予了图像它们所获得的格式。
这是我用来加载图像然后显示它们的代码。
f, axarr = plt.subplots(2,2, figsize=(20,20))
def load_dataset():
data_path = 'processedData/HE/train/'
train_dataset = torchvision.datasets.ImageFolder(
root=data_path,
transform=torchvision.transforms.ToTensor()
)
train_loader = DataLoader(
train_dataset, batch_size=64
)
return train_loader
x_train = load_dataset()
datathing = next(iter(x_train))
for i, ax in enumerate(axarr.flat):
ax.imshow(datathing[0][i].view(128,128,3))
ax.axis('off')
plt.show()
对图像运行此命令时,输出looks like this.
假设看起来like these images
我一直在尝试使用不同的图像数据集,但是所有集合都返回相同的格式,所以我的问题是:
答案 0 :(得分:0)
katyperry
弄乱了图像。
您可以阅读转换.ToTensor(...)
的文档:
[...]
将范围为[0,255]的PIL图像或numpy.ndarray (H xW x C)转换为形状为(C xH x W)的手电筒。 在[0.0,1.0]
范围内[...]
即,通道尺寸从最后一个尺寸移动到第一个尺寸。您可以在source code中看到它:
.view(128, 128, 3)
因此,您不能简单地致电# ...
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# ...
;您必须将其转回。在PyTorch中,您可以使用.view(...)
函数。像这样:
.permute(...)