plt.imshow()给出TypeError:无法使用PyTorch将dtype对象的图像数据转换为float

时间:2019-09-15 10:02:06

标签: python image matplotlib pytorch

我正在尝试为一组图像创建自定义数据集处理器。但是,当我尝试查看数据集中的图像时,遇到TypeError的错误:dtype对象的图像数据无法转换为float。

我试图检查是否将PIL图像传递到plt.imshow()函数中。

class DatasetProcessing(Dataset):

    def __init__(self, input_data, output_data, transform=None): 
        self.transform = transform
        self.input_data = 
        input_data.reshape((-1,64,64)).astype(np.float32)[:,:,:,None]
        self.output_data = output_data 

    def __getitem__(self, index): 
        return self.transform(self.input_data[index]), self.output_data[index]

    def __len__(self): 
        return len(list(self.input_data))

transform = transforms.Compose([transforms.ToPILImage()])


dset_train = DatasetProcessing(X_slices_train, Y_train, transform)


train_loader = torch.utils.data.DataLoader(dset_train, batch_size=4, 
                                      shuffle=True, num_workers=4) 

plt.figure(figsize = (16, 4))
for num, x in enumerate(dset_train):
    plt.subplot(1,6,num+1)
    plt.axis('off')
    print(x)
    plt.imshow(np.asarray(x))
    plt.title(y_train[num])

我希望获取我的数据集的图片,但是却收到以下错误消息:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-8b8caac49d97> in <module>
      4     plt.axis('off')
      5     print(x)
----> 6     plt.imshow(np.asarray(x))
      7     plt.title(y_train[num])

~/anaconda3/lib/python3.7/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, data, **kwargs)
   2675         filternorm=filternorm, filterrad=filterrad, imlim=imlim,
   2676         resample=resample, url=url, **({"data": data} if data is not
-> 2677         None else {}), **kwargs)
   2678     sci(__ret)
   2679     return __ret

~/anaconda3/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1587     def inner(ax, *args, data=None, **kwargs):
   1588         if data is None:
-> 1589             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1590 
   1591         bound = new_sig.bind(ax, *args, **kwargs)

~/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    367                 f"%(removal)s.  If any parameter follows {name!r}, they "
    368                 f"should be pass as keyword, not positionally.")
--> 369         return func(*args, **kwargs)
    370 
    371     return wrapper

~/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
    367                 f"%(removal)s.  If any parameter follows {name!r}, they "
    368                 f"should be pass as keyword, not positionally.")
--> 369         return func(*args, **kwargs)
    370 
    371     return wrapper

~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
   5658                               resample=resample, **kwargs)
   5659 
-> 5660         im.set_data(X)
   5661         im.set_alpha(alpha)
   5662         if im.get_clip_path() is None:

~/anaconda3/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
    676                 not np.can_cast(self._A.dtype, float, "same_kind")):
    677             raise TypeError("Image data of dtype {} cannot be converted to "
--> 678                             "float".format(self._A.dtype))
    679 
    680         if not (self._A.ndim == 2

TypeError: Image data of dtype object cannot be converted to float

1 个答案:

答案 0 :(得分:0)

如果正确理解,您的dset_train会产生self.transform(self.input_data[index]), self.output_data[index] self.transform(self.input_data[index])是图像张量(数据),self.output_data[index]是标签,但是在这里:

plt.imshow(np.asarray(x))

您传递的是未包装的x,实际上是(数据,标签)

因此,您需要先将其打开包装:

plt.figure(figsize = (16, 4))
for num, x in enumerate(dset_train):
    data, label = x
    plt.subplot(1,6,num+1)
    plt.axis('off')
    print(x)
    plt.imshow(np.asarray(data))
    plt.title(y_train[num])

编辑:

  

为什么我必须打开x的包装?

您是从PyTorch的{​​{1}}继承而来的,根据docs

  

表示从键到数据样本的映射的所有数据集都应将其子类化。所有子类都应覆盖Dataset,以支持获取给定键的数据样本。

在您定义的__getitem__()DatasetProcessing中,返回两个元组的元组:__getitem__()self.transform(self.input_data[index]),第一个是数据,第二个是适当的标签。这就是为什么您需要像self.output_data[index]那样解压缩,因为data, y = x数据集会产生数据和标签。

  

您可以将我链接到任何文档/教程吗?

我可以向您推荐此链接: