我正在 pytorch 中制作我自己的自定义数据集。
并且,我想将图像可视化。
但是,我认为自定义数据集出现了问题。
请帮帮我。
<块引用>NotImplementedError Traceback(最近一次调用 最后) 在 () 1 导入 matplotlib.pyplot 作为 plt 2 dat=TrainDataset(transforms.ToTensor()) ----> 3 img,label= dat[i] 4 plt.imshow(img.permute(1,2,0))
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataset.py 中 getitem(自我,索引) 31 32 def getitem(self, index) -> T_co: ---> 33 引发 NotImplementedError 34 35 def add(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
NotImplementedError:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import os
import glob
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
class TrainDataset(Dataset):
def __init__(self, transform):
super().__init__()
self.data = pd.read_csv('/content/drive/MyDrive/cancer/train_labels.csv')
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self,idx):
img_name, label = self.data.iloc[idx]
img = Image.open(f'/content/drive/MyDrive/cancer/test/{image_name}.tif')
img = self.transform(img)
return (img, torch.tensor(label).long())
import matplotlib.pyplot as plt
dat= TrainDataset(transforms.ToTensor())
img,label= dat[1]
plt.imshow(img.permute(1,2,0))