为什么该数据集尝试迭代最后一个元素
from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
def __init__(self, dct):
self.dct = dct
self.mapping = dict(enumerate(dct))
def __getitem__(self, index):
return self.dct[self.mapping[index]]
def __len__(self):
print('called')
return len(self.dct)
ds = DumbDataset({'a': 'aword', 'b': 'another_words'})
for k in ds: print(k)
这会引发KeyError:2,由于对象的长度为2,所以我不理解。一旦迭代器用尽,迭代器是否应该获取StopIteration?
答案 0 :(得分:3)
您的代码引发KeyError
的原因是Dataset
does not implement __iter__()
,因此在for循环Python中使用时会退回到索引{{1 }}并调用0
直到引发__getitem__
,如here所述。您可以修改IndexError
以使其工作,方法是在索引超出范围时使其升高DumbDataset
IndexError
然后是循环
def __getitem__(self, index):
if index >= len(self): raise IndexError
return self.dct[self.mapping[index]]
将按预期工作。另一方面,火炬数据集的典型模板是您可以通过索引遍历它们
for k in ds:
print(k)
或者将它们包装在for i in range(len(ds)):
k = ds[k]
print(k)
中,它会批量返回元素
DataLoader