Custom dataset not accepting arguments in PyTorch

Time: 2021-05-08 11:34:54

Tags: python computer-vision pytorch dataset pytorch-dataloader

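I am trying to create a custom dataset in PyTorch using this dataset. It has shape (X, 785), where X is the number of samples. Each row contains the label at index 0, followed by 784 pixel values. Here is my code: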

from torch.utils.data import Dataset
def SignMNISTDataset(Dataset):

  def __init__(self, csv_file_path, mode='Train'):
    self.labels = []
    self.pixels = []
    self.mode = mode

    data = pd.read_csv(csv_file_path).values
    if self.mode == 'Train':
      self.labels = data[:,0].tolist()
      print("Training labels acquired")

    for idx in range(len(self.labels)):
      self.pixels.append(data[idx][1:].tolist())

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    pixels = self.pixels[idx]
    if self.mode == 'Train':
      labels = self.labels[idx]
      content = {"pixels":pixels, "label":labels}
    else:
      content = {"pixels":pixels}
    return content

training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')

Running it produces the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-46-0173199f8794> in <module>()
     27     return content
     28 
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')
     30 from torch.utils.data import DataLoader
     31 

TypeError: SignMNISTDataset() takes 1 positional argument but 2 were given

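Where exactly is this coming from? Is the mode argument somehow not being read during object creation? My end goal is to build a neural network that classifies sign-language characters, following this tutorial.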

I also tried passing mode explicitly as a keyword argument when creating the object. This is what I got:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-48-fd796c48dc67> in <module>()
     27     return content
     28 
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', mode='Train')

TypeError: SignMNISTDataset() got an unexpected keyword argument 'mode'

1 answer:

Answer 0 (score: 2)

Use

class SignMNISTDataset(Dataset):

instead of

def SignMNISTDataset(Dataset):
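
To see why that one keyword matters: with def, SignMNISTDataset is an ordinary function whose single parameter happens to be named Dataset, and the nested __init__ is just a local function that never runs. That would explain both TypeError messages in the question. Below is a minimal sketch of the difference. The Dataset class here is a stand-in so the snippet runs without PyTorch, and BrokenDataset, FixedDataset, call_and_capture and the 'train.csv' path are hypothetical names used only for illustration (the file is never opened).

```python
# Stand-in for torch.utils.data.Dataset so the sketch runs without PyTorch.
# This is an assumption for illustration only; it is not the real base class.
class Dataset:
    pass


# With `def`, this is a function taking exactly one parameter named `Dataset`.
# The nested `__init__` is only a local function and is never called.
def BrokenDataset(Dataset):
    def __init__(self, csv_file_path, mode='Train'):
        self.mode = mode


# With `class`, `Dataset` is the base class and `__init__` receives the arguments.
class FixedDataset(Dataset):
    def __init__(self, csv_file_path, mode='Train'):
        self.csv_file_path = csv_file_path
        self.mode = mode


def call_and_capture(func, *args, **kwargs):
    """Return the TypeError message raised by the call, or None if it succeeds."""
    try:
        func(*args, **kwargs)
    except TypeError as exc:
        return str(exc)
    return None


# Same two call styles as in the question, applied to the `def` version.
positional_error = call_and_capture(BrokenDataset, 'train.csv', 'Train')
keyword_error = call_and_capture(BrokenDataset, 'train.csv', mode='Train')
print(positional_error)
print(keyword_error)

# The `class` version accepts both arguments and builds an instance.
fixed = FixedDataset('train.csv', mode='Train')
print(type(fixed).__name__, fixed.mode)
```

Note that calling the def version with a single argument raises no error at all. It simply returns None, which can hide the mistake until a second argument is passed.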