火炬数据集循环太远

时间:2019-02-11 23:57:31

标签: python pytorch

为什么该数据集尝试迭代最后一个元素

from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
    def __init__(self, dct):
        self.dct = dct
        self.mapping = dict(enumerate(dct))
    def __getitem__(self, index):
        return self.dct[self.mapping[index]]

    def __len__(self):
        print('called')
        return len(self.dct)

ds = DumbDataset({'a': 'aword', 'b': 'another_words'})

for k in ds: print(k)

这会引发KeyError:2,由于对象的长度为2,所以我不理解。一旦迭代器用尽,迭代器是否应该获取StopIteration?

1 个答案:

答案 0 :(得分:3)

您的代码引发KeyError的原因是Dataset does not implement __iter__(),因此在for循环Python中使用时会退回到索引{{1 }}并调用0直到引发__getitem__,如here所述。您可以修改IndexError以使其工作,方法是在索引超出范围时使其升高DumbDataset

IndexError

然后是循环

def __getitem__(self, index):
    if index >= len(self): raise IndexError
    return self.dct[self.mapping[index]]

将按预期工作。另一方面,火炬数据集的典型模板是您可以通过索引遍历它们

for k in ds:
    print(k)

或者将它们包装在for i in range(len(ds)): k = ds[k] print(k) 中,它会批量返回元素

DataLoader