PyTorch数据集中使用的len函数在哪里?

时间:2018-02-04 13:31:03

标签: python machine-learning deep-learning conv-neural-network pytorch

我希望使用here中的代码。 但是,我正在查看方框5,其中有以下功能;

def __len__(self):
    # Default epoch size is 10 000 samples
    return 10000

我没有在这个脚本中看到使用此功能的任何地方。 对此的澄清将不胜感激。

另外,我想确定用于训练这个卷积神经网络的图像补丁的数量。这个len函数是否与补丁数量相关联?

1 个答案:

答案 0 :(得分:1)

这是Dataset类的函数。 __len__()函数指定数据集的大小。在引用的代码中,在框10中,初始化数据集并将其传递给DataLoader对象:

train_set = ISPRS_dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE)

您会看到DataLoader传递数据集对象以及批量大小。然后,DataLoader对象使用数据集的__len__函数来创建批处理。这发生在方框13中,它在DataLoader上进行迭代。