pytorch 4d numpy数组在自定义数据集中应用transfroms

时间:2020-06-16 17:06:37

标签: pytorch torchvision

在我的自定义数据集中,我想将transforms.Compose()应用于NumPy数组。

我的图像采用NumPy数组格式,形状为(num_samples, width, height, channels)

如何将以下转换应用于整个numpy数组?

img_transform = transforms.Compose([ transforms.Scale((224,224)), transforms.ToTensor(), transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32]) ])

由于转换接受的是PIL图片而非4-d NumPy数组,因此我的尝试以多个错误结束。

from torchvision import transforms
import numpy as np
import torch

img_transform = transforms.Compose([
        transforms.Scale((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

a = np.random.randint(0,256, (299,299,3))
print(a.shape)

img_transform(a)

1 个答案:

答案 0 :(得分:1)

所有的Torchvision变换都对单个图像而不是成批图像进行操作,因此无法使用4D阵列。

以NumPy数组形式给出的单个图像(如您的代码示例中所示)可以通过将其转换为PIL图像来使用。您只需将transforms.ToPILImage添加到转换管道的开头,即可将张量或NumPy数组转换为PIL图像。

img_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

注意:不推荐使用transforms.Scaletransforms.Resize

在您的示例中,您使用了np.random.randint,默认情况下使用的类型是int64,但是图像必须是uint8。像OpenCV这样的库在加载图像时会返回uint8数组。

a = np.random.randint(0,256, (299,299,3), dtype=np.uint8)