我正在尝试为一组图像创建自定义数据集处理器。但是,当我尝试查看数据集中的图像时,遇到TypeError的错误:dtype对象的图像数据无法转换为float。
我试图检查是否将PIL图像传递到plt.imshow()函数中。
class DatasetProcessing(Dataset):
def __init__(self, input_data, output_data, transform=None):
self.transform = transform
self.input_data =
input_data.reshape((-1,64,64)).astype(np.float32)[:,:,:,None]
self.output_data = output_data
def __getitem__(self, index):
return self.transform(self.input_data[index]), self.output_data[index]
def __len__(self):
return len(list(self.input_data))
transform = transforms.Compose([transforms.ToPILImage()])
dset_train = DatasetProcessing(X_slices_train, Y_train, transform)
train_loader = torch.utils.data.DataLoader(dset_train, batch_size=4,
shuffle=True, num_workers=4)
plt.figure(figsize = (16, 4))
for num, x in enumerate(dset_train):
plt.subplot(1,6,num+1)
plt.axis('off')
print(x)
plt.imshow(np.asarray(x))
plt.title(y_train[num])
我希望获取我的数据集的图片,但是却收到以下错误消息:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-10-8b8caac49d97> in <module>
4 plt.axis('off')
5 print(x)
----> 6 plt.imshow(np.asarray(x))
7 plt.title(y_train[num])
~/anaconda3/lib/python3.7/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, data, **kwargs)
2675 filternorm=filternorm, filterrad=filterrad, imlim=imlim,
2676 resample=resample, url=url, **({"data": data} if data is not
-> 2677 None else {}), **kwargs)
2678 sci(__ret)
2679 return __ret
~/anaconda3/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
1587 def inner(ax, *args, data=None, **kwargs):
1588 if data is None:
-> 1589 return func(ax, *map(sanitize_sequence, args), **kwargs)
1590
1591 bound = new_sig.bind(ax, *args, **kwargs)
~/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
367 f"%(removal)s. If any parameter follows {name!r}, they "
368 f"should be pass as keyword, not positionally.")
--> 369 return func(*args, **kwargs)
370
371 return wrapper
~/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*args, **kwargs)
367 f"%(removal)s. If any parameter follows {name!r}, they "
368 f"should be pass as keyword, not positionally.")
--> 369 return func(*args, **kwargs)
370
371 return wrapper
~/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
5658 resample=resample, **kwargs)
5659
-> 5660 im.set_data(X)
5661 im.set_alpha(alpha)
5662 if im.get_clip_path() is None:
~/anaconda3/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
676 not np.can_cast(self._A.dtype, float, "same_kind")):
677 raise TypeError("Image data of dtype {} cannot be converted to "
--> 678 "float".format(self._A.dtype))
679
680 if not (self._A.ndim == 2
TypeError: Image data of dtype object cannot be converted to float
答案 0 :(得分:0)
如果正确理解,您的dset_train
会产生self.transform(self.input_data[index]), self.output_data[index]
self.transform(self.input_data[index])
是图像张量(数据),self.output_data[index]
是标签,但是在这里:
plt.imshow(np.asarray(x))
您传递的是未包装的x
,实际上是(数据,标签)
因此,您需要先将其打开包装:
plt.figure(figsize = (16, 4))
for num, x in enumerate(dset_train):
data, label = x
plt.subplot(1,6,num+1)
plt.axis('off')
print(x)
plt.imshow(np.asarray(data))
plt.title(y_train[num])
编辑:
为什么我必须打开x的包装?
您是从PyTorch的{{1}}继承而来的,根据docs:
表示从键到数据样本的映射的所有数据集都应将其子类化。所有子类都应覆盖
Dataset
,以支持获取给定键的数据样本。
在您定义的__getitem__()
类DatasetProcessing
中,返回两个元组的元组:__getitem__()
和self.transform(self.input_data[index])
,第一个是数据,第二个是适当的标签。这就是为什么您需要像self.output_data[index]
那样解压缩,因为data, y = x
数据集会产生数据和标签。
您可以将我链接到任何文档/教程吗?
我可以向您推荐此链接: