PyTorch - Getting the 'TypeError: pic should be PIL Image or ndarray. Got <class 'numpy.ndarray'>' error

Time: 2019-06-24 17:07:54

Tags: python deep-learning pytorch torch torchvision

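I get the error TypeError: pic should be PIL Image or ndarray. Got <class 'numpy.ndarray'> when I try to load a non-image dataset through the DataLoader. The versions of torch and torchvision are 1.0.1 and 0.2.2.post3, respectively. The Python version is 3.7.1, on a Windows 10 machine.

Here is the code: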

class AndroDataset(Dataset):
    def __init__(self, csv_path):
        self.transform = transforms.Compose([transforms.ToTensor()])

        csv_data = pd.read_csv(csv_path)

        self.csv_path = csv_path
        self.features = []
        self.classes = []

        self.features.append(csv_data.iloc[:, :-1].values)
        self.classes.append(csv_data.iloc[:, -1].values)

    def __getitem__(self, index):
        # the error occurs here
        return self.transform(self.features[index]), self.transform(self.classes[index]) 

    def __len__(self):
        return len(self.features)

Then I set up the loader:

training_data = AndroDataset('android.csv')
train_loader = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True)

Here is the full error stack trace:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2018.1.2\helpers\pydev\pydevd.py", line 1758, in <module>
    main()
  File "C:\Program Files\JetBrains\PyCharm 2018.1.2\helpers\pydev\pydevd.py", line 1752, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Program Files\JetBrains\PyCharm 2018.1.2\helpers\pydev\pydevd.py", line 1147, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2018.1.2\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/talha/Documents/PyCharmProjects/DeepAndroid/deep_test_conv1d.py", line 231, in <module>
    main()
  File "C:/Users/talha/Documents/PyCharmProjects/DeepAndroid/deep_test_conv1d.py", line 149, in main
    for i, (images, labels) in enumerate(train_loader):
  File "C:\Users\talha\Documents\PyCharmProjects\DeepAndroid\venv\lib\site-packages\torch\utils\data\dataloader.py", line 615, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "C:\Users\talha\Documents\PyCharmProjects\DeepAndroid\venv\lib\site-packages\torch\utils\data\dataloader.py", line 615, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "C:/Users/talha/Documents/PyCharmProjects/DeepAndroid/deep_test_conv1d.py", line 102, in __getitem__
    return self.transform(self.features[index]), self.transform(self.classes[index])
  File "C:\Users\talha\Documents\PyCharmProjects\DeepAndroid\venv\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
    img = t(img)
  File "C:\Users\talha\Documents\PyCharmProjects\DeepAndroid\venv\lib\site-packages\torchvision\transforms\transforms.py", line 91, in __call__
    return F.to_tensor(pic)
  File "C:\Users\talha\Documents\PyCharmProjects\DeepAndroid\venv\lib\site-packages\torchvision\transforms\functional.py", line 50, in to_tensor
    raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
TypeError: pic should be PIL Image or ndarray. Got <class 'numpy.ndarray'>

3 Answers:

Answer 0: (score: 0)

This happens because of the transformation you use:

self.transform = transforms.Compose([transforms.ToTensor()])

As you can see in the documentation, torchvision.transforms.ToTensor converts a PIL Image or numpy.ndarray to a tensor. So if you want to use this transformation, your data has to be one of those types.
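The message looks contradictory, since an ndarray was passed and the error still says "Got <class 'numpy.ndarray'>". A likely explanation, assuming torchvision 0.2.x behaves as I remember, is that the check inside to_tensor also requires the array to have 2 or 3 dimensions, like an image. The sketch below imitates that check with a hypothetical helper, looks_like_numpy_image (not a torchvision function), on made-up arrays:

```python
import numpy as np

def looks_like_numpy_image(pic):
    # Hypothetical helper: imitates the ndarray check that torchvision 0.2.x
    # appears to apply inside to_tensor (an assumption, not the library's API).
    return isinstance(pic, np.ndarray) and pic.ndim in {2, 3}

features = np.ones((100, 38))  # 2-D: rows x columns, shaped like H x W
labels = np.ones(100)          # 1-D: one class label per row

print(looks_like_numpy_image(features))  # True
print(looks_like_numpy_image(labels))    # False
```

If that assumption holds, the 1-D label array is what trips the check in __getitem__, even though its type is numpy.ndarray.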

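Answer 1: (score: 0)

Expanding on @MiriamFarber's answer, you cannot use transforms.ToTensor() on numpy.ndarray objects that are not image-shaped (2 or 3 dimensions). You can convert numpy arrays to torch tensors using torch.from_numpy() and then cast the tensor to the required datatype.


For example: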

>>> import numpy as np
>>> import torch
>>> np_arr = np.ones((5289, 38))
>>> torch_tensor = torch.from_numpy(np_arr).long()
>>> type(np_arr)
<class 'numpy.ndarray'>
>>> type(torch_tensor)
<class 'torch.Tensor'>
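Applied to a dataset class, the same idea might look like the sketch below. ArrayDataset and the random arrays are hypothetical stand-ins for the CSV-backed class in the question, and the float/long dtypes are an assumption about what the model and loss expect:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ArrayDataset(Dataset):
    """Hypothetical dataset wrapping tabular features and integer labels."""

    def __init__(self, features, labels):
        # torch.from_numpy wraps each array as a tensor;
        # .float() / .long() then cast to the assumed dtypes.
        self.features = torch.from_numpy(features).float()
        self.labels = torch.from_numpy(labels).long()

    def __getitem__(self, index):
        # One row of features and its label, no torchvision transform needed.
        return self.features[index], self.labels[index]

    def __len__(self):
        return len(self.features)

# Made-up data standing in for the CSV contents: 10 rows, 38 feature columns.
features = np.random.rand(10, 38)
labels = np.random.randint(0, 2, size=10)

dataset = ArrayDataset(features, labels)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_x, batch_y = next(iter(loader))
print(batch_x.shape, batch_y.shape)
```

Storing the tensors directly, rather than appending a whole array to a list as the question's code does, also makes len(dataset) equal to the number of rows instead of 1.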

Answer 2: (score: 0)

If you want to use torchvision.transforms on a numpy array, first convert the numpy array to a PIL Image object using transforms.ToPILImage().