我要显示一张图像。它是使用ImageLoader
加载的,并存储在PyTorch Tensor
中。
当我尝试通过plt.imshow(image)
显示它时,我得到:
TypeError: Invalid dimensions for image data
张量的.shape
是:
torch.Size([3, 244, 244])
如何显示此PyTorch张量中包含的图像?
答案 0 :(得分:7)
给出一个表示图像的Tensor
,请使用.permute()
:
plt.imshow( tensor_image.permute(1, 2, 0) )
注意:permute
does not copy or allocate memory和 from_numpy()
doesn't either.
答案 1 :(得分:6)
如您所见,即使不转换为matplotlib
数组,numpy
也可以正常工作。但是PyTorch张量(“图像张量”)是第一个通道,因此要将它们与matplotlib
一起使用,您需要对其进行重塑:
代码:
from scipy.misc import face
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)
# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)
# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
输出:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
答案 2 :(得分:1)
处理图像数据的 PyTorch 模块需要 C × H × W 格式的张量。1
而 PILlow 和 Matplotlib 需要格式为 H × W × C.2
您可以使用 TorchVision 变换轻松地将张量转换为/from 这种格式:
from torchvision import transforms.functional as F
F.to_pil_image(image_tensor)
或者直接排列坐标轴:
image_tensor.permute(1,2,0)
处理图像数据的 PyTorch 模块需要将张量布局为 C × H × W:分别为通道、高度和宽度。
注意我们如何使用 permute
将轴的顺序从 C × H × W 更改为 H × W × C 以匹配什么Matplotlib 期望。
答案 3 :(得分:0)
鉴于图像已按照说明加载并存储在变量image
中:
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
在放大照片时经常使用双三次插值-人们倾向于模糊而不是像素化。
%matplotlib inline
def show(img):
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
或者,要在弹出窗口中打开图像:
transforms.ToPILImage()(image).show()
答案 4 :(得分:0)
给出图像路径名img_path
的完整示例:
from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")
请注意,transforms.*
返回一个函数,这就是为什么将括号括起来的原因。
答案 5 :(得分:0)