如何在PyTorch中显示单个图像?

时间:2018-12-05 00:33:01

标签: python pytorch

我要显示一张图像。它是使用ImageLoader加载的,并存储在PyTorch Tensor中。

当我尝试通过plt.imshow(image)显示它时,我得到:

TypeError: Invalid dimensions for image data

张量的.shape是:

torch.Size([3, 244, 244])

如何显示此PyTorch张量中包含的图像?

6 个答案:

答案 0 :(得分:7)

给出一个表示图像的Tensor,请使用.permute()

plt.imshow(  tensor_image.permute(1, 2, 0)  )

注意:permute does not copy or allocate memory from_numpy() doesn't either.

答案 1 :(得分:6)

如您所见,即使不转换为matplotlib数组,numpy也可以正常工作。但是PyTorch张量(“图像张量”)是第一个通道,因此要将它们与matplotlib一起使用,您需要对其进行重塑:

代码:

from scipy.misc import face
import matplotlib.pyplot as plt
import torch

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)

# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)

# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

输出:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])

答案 2 :(得分:1)

处理图像数据的 PyTorch 模块需要 C × H × W 格式的张量。1
而 PILlow 和 Matplotlib 需要格式为 H × W × C.2

的图像数组

您可以使用 TorchVision 变换轻松地将张量转换为/from 这种格式:

from torchvision import transforms.functional as F

F.to_pil_image(image_tensor)

或者直接排列坐标轴:

image_tensor.permute(1,2,0)

  1. <块引用>

    处理图像数据的 PyTorch 模块需要将张量布局为 C × H × W:分别为通道、高度和宽度。

  2. <块引用>

    注意我们如何使用 permute 将轴的顺序从 C × H × W 更改为 H × W × C 以匹配什么Matplotlib 期望。

答案 3 :(得分:0)

鉴于图像已按照说明加载并存储在变量image中:

plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")

matplotlib image tutorial说:

  

在放大照片时经常使用双三次插值-人们倾向于模糊而不是像素化。


或作为Soumith suggested

%matplotlib inline
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

或者,要在弹出窗口中打开图像:

 transforms.ToPILImage()(image).show()

答案 4 :(得分:0)

给出图像路径名img_path的完整示例:

from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")

请注意,transforms.*返回一个函数,这就是为什么将括号括起来的原因。

答案 5 :(得分:0)

使用来自 fastai 的 show_image

from fastai.vision.all import show_image

enter image description here

enter image description here