防止PyTorch数据集迭代超出数据集的长度

时间:2019-04-11 16:30:32

标签: python pytorch

我正在将自定义PyTorch数据集用于以下内容:

class ImageDataset(Dataset):
    def __init__(self, input_dir, input_num, input_format, transform=None):
        self.input_num = input_num
        # etc
    def __len__ (self):
        return self.input_num
    def __getitem__(self,idx):
        targetnum = idx % self.input_num
        # etc

但是,当我遍历此数据集时,迭代会循环回到数据集的开头,而不是在数据集的结尾处终止。这实际上会成为迭代器中的无限循环,并且以后的历元永远不会发生历元打印语句。

train_dataset=ImageDataset(input_dir = 'path/to/directory', 
                           input_num = 300, input_format = "mask") # Size 300
num_epochs = 10
for epoch in range(num_epochs):
    print("EPOCH " + str(epoch+1) + "\n")
    num = 0
    for data in train_dataset:
        print(num, end=" ")
        num += 1
        # etc

打印输出(...之间的值):

EPOCH 1
0 1 2 3 4 5 6 7 ... 298 299 300 301 302 303 304 305 ... 597 598 599 600 601 602 603 604 ...

为什么对数据集的基本迭代会持续经过数据集的已定义__len__,以及如何确保使用此方法(或手动进行)后,达到数据集的长度后,对数据集的迭代会终止在数据集长度范围内进行迭代是唯一的解决方案)?

谢谢。

1 个答案:

答案 0 :(得分:0)

Dataset类尚未实现StopIteration信号。

  

for循环侦听StopIteration。 for语句的目的是循环遍历迭代器提供的序列,并且该异常用于表示迭代器现在已完成...

更多:Why does next raise a 'StopIteration', but 'for' do a normal return? | The Iterator Protocol