我正在尝试使用 this dataset 在 PyTorch 中创建自定义数据集。它的形状为 (X, 785),X 是样本数,每行包含索引 0 和 784 像素值处的标签。这是我的代码:
from torch.utils.data import Dataset
def SignMNISTDataset(Dataset):
def __init__(self, csv_file_path, mode='Train'):
self.labels = []
self.pixels = []
self.mode = mode
data = pd.read_csv(csv_file_path).values
if self.mode == 'Train':
self.labels = data[:,0].tolist()
print("Training labels acquired")
for idx in range(len(self.labels)):
self.pixels.append(data[idx][1:].tolist())
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
pixels = self.pixels[idx]
if self.mode == 'Train':
labels = self.labels[idx]
content = {"pixels":pixels, "label":labels}
else:
content = {"pixels":pixels}
return content
training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')
运行时出现以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-46-0173199f8794> in <module>()
27 return content
28
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')
30 from torch.utils.data import DataLoader
31
TypeError: SignMNISTDataset() takes 1 positional argument but 2 were given
这究竟是从哪里来的?在对象创建过程中,模式参数是否以某种方式不被读取? 我的最终目标是创建一个神经网络来对符号字符进行分类,遵循 this tutorial。
我尝试在对象创建期间明确提及关键字 mode
。这就是我得到的 -
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-48-fd796c48dc67> in <module>()
27 return content
28
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', mode='Train')
TypeError: SignMNISTDataset() got an unexpected keyword argument 'mode'
答案 0 :(得分:2)
请使用
class SignMNISTDataset(Dataset):
代替
def SignMNISTDataset(Dataset):