在我的自定义数据集中,我想将transforms.Compose()
应用于NumPy数组。
我的图像采用NumPy数组格式,形状为(num_samples, width, height, channels)
。
如何将以下转换应用于整个numpy数组?
img_transform = transforms.Compose([
transforms.Scale((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
])
由于转换接受的是PIL图片而非4-d NumPy数组,因此我的尝试以多个错误结束。
from torchvision import transforms
import numpy as np
import torch
img_transform = transforms.Compose([
transforms.Scale((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
])
a = np.random.randint(0,256, (299,299,3))
print(a.shape)
img_transform(a)
答案 0 :(得分:1)
所有的Torchvision变换都对单个图像而不是成批图像进行操作,因此无法使用4D阵列。
以NumPy数组形式给出的单个图像(如您的代码示例中所示)可以通过将其转换为PIL图像来使用。您只需将transforms.ToPILImage
添加到转换管道的开头,即可将张量或NumPy数组转换为PIL图像。
img_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
])
注意:不推荐使用transforms.Scale
的transforms.Resize
。
在您的示例中,您使用了np.random.randint
,默认情况下使用的类型是int64,但是图像必须是uint8。像OpenCV这样的库在加载图像时会返回uint8数组。
a = np.random.randint(0,256, (299,299,3), dtype=np.uint8)