Pytorch无法转换numpy.object类型的np.ndarray

时间:2020-05-30 09:20:53

标签: python-3.x pytorch classification

我正在尝试创建一个具有可变图像大小的PyTorch数据加载器。这是我的代码段

def get_imgs(path_to_imgs):

    imgs = []
    for path in path_to_imgs:

        imgs.append(cv2.imread(path))

    imgs = np.asarray(imgs)    

    return imgs   

上面的函数获取路径列表并将图像从路径加载到列表“ imgs”。顺便说一句,图像大小不相等。该列表看起来像imgs = [NumPy数组,NumPy数组....]。但是,当我将列表转换为np.asarray时,会将列表转换为dtype = object。

这是我的数据加载器类

class Dataset(torch.utils.data.Dataset):

  def __init__(self, path_to_imgs, path_to_label):
        'Initialization'
        self.path_to_imgs = path_to_imgs
        self.path_to_label = path_to_label

        self.imgs = get_imgs(path_to_imgs)
        self.label = get_pts(path_to_label)

        self.imgs = torch.Tensor(self.imgs)             **Error here
        # self.imgs = torch.from_numpy(self.imgs)       ** I tried this as well. Same error

        self.label = torch.Tensor(self.label)

        self.len = len(self.imgs)

  def __len__(self):
        'Denotes the total number of samples'
        return self.len

  def __getitem__(self, index):

        return self.imgs, self.label

当我尝试将图像列表转换为张量**时,它失败并给出以下错误

无法转换numpy.object_类型的np.ndarray。唯一受支持的类型为:float64,float32,float16,int64,int32,int16,int8,uint8和bool。

我看过类似的问题herehere,但它们没有帮助。

1 个答案:

答案 0 :(得分:1)

def get_imgs(path_to_imgs):

    imgs = []
    for path in path_to_imgs:
        imgs.append(torch.Tensor(cv2.imread(path)))

    return imgs
class Dataset(torch.utils.data.Dataset):
    def __init__(self, path_to_imgs, path_to_label):
        'Initialization'
        self.path_to_imgs = path_to_imgs
        self.path_to_label = path_to_label

        self.imgs = get_imgs(path_to_imgs)
        self.label = get_pts(path_to_label)

        # padding ops here (https://pytorch.org/docs/stable/nn.html#padding-layers)
        # for img in self.imgs:
        #     ...

        self.label = torch.Tensor(self.label)

        self.len = len(self.imgs)

    def __len__(self):
        'Denotes the total number of samples'
        return self.len

    def __getitem__(self, index):

        return self.imgs, self.label