pytorch 自定义数据集未实现错误

时间:2021-04-17 15:31:56

标签: pytorch

我正在 pytorch 中制作我自己的自定义数据集。

并且,我想将图像可视化。

但是,我认为自定义数据集出现了问题。

请帮帮我。

<块引用>

NotImplementedError Traceback(最近一次调用 最后) 在 () 1 导入 matplotlib.pyplot 作为 plt 2 dat=TrainDataset(transforms.ToTensor()) ----> 3 img,label= dat[i] 4 plt.imshow(img.permute(1,2,0))

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataset.py 中 getitem(自我,索引) 31 32 def getitem(self, index) -> T_co: ---> 33 引发 NotImplementedError 34 35 def add(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':

NotImplementedError:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import os
import glob

from torch.utils.data import Dataset

import pandas as pd
from PIL import Image


class TrainDataset(Dataset):
  def __init__(self, transform):
    super().__init__()
    self.data = pd.read_csv('/content/drive/MyDrive/cancer/train_labels.csv')
    self.transform = transform
    
    def __len__(self):
      return len(self.data)

    def __getitem__(self,idx):
      img_name, label = self.data.iloc[idx]
      img = Image.open(f'/content/drive/MyDrive/cancer/test/{image_name}.tif')
      img = self.transform(img)
      return (img, torch.tensor(label).long())



import matplotlib.pyplot as plt
dat= TrainDataset(transforms.ToTensor())
img,label= dat[1]
plt.imshow(img.permute(1,2,0))

0 个答案:

没有答案
相关问题