如何从PyTorch中的数据加载器获取整个数据集

时间:2019-08-07 04:13:08

标签: python pytorch dataloader

如何从DataLoader加载整个数据集?我只得到一批数据集。

这是我的代码

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))

3 个答案:

答案 0 :(得分:1)

我不确定您是否要在网络培训之外的其他地方使用数据集(例如检查图像),还是要在培训期间遍历批次。

遍历数据集

或者跟随乌斯曼·阿里(Usman Ali)的回答(可能会溢出),或者您可以这样做

for i in range(len(dataset)): # or i, image in enumerate(dataset)
    images, labels = dataset[i] # or whatever your dataset returns

您能够编写dataset[i]是因为您在__len__类中实现了__getitem__Dataset(只要它是Pytorch Dataset的子类类)。

从数据加载器中获取所有批次

我理解您的问题的方式是,您想检索所有批次以训练网络。您应该了解iter为您提供了数据加载器的迭代器(如果您不熟悉迭代器的概念,请参见wikipedia entry)。 next告诉迭代器给您下一项。

因此,与遍历列表的迭代器相反,数据加载器始终返回下一个项目。列表迭代器在某个时刻停止。我假设您有一些时期,每个时期都有许多步骤。然后您的代码将如下所示

for i in range(epochs):
    # some code
    for j in range(steps_per_epoch):
        images, labels = next(iter(dataloader))
        prediction = net(images)
        loss = net.loss(prediction, labels)
        ...

请注意next(iter(dataloader))。如果要遍历列表,这可能也可以工作,因为Python缓存了对象,但是每次从索引0开始时,您都可能会得到一个新的迭代器。为避免这种情况,请从顶部取出迭代器,如下所示:

iterator = iter(dataloader)
for i in range(epochs):
    for j in range(steps_per_epoch):
        images, labels = next(iterator)

答案 1 :(得分:1)

另一种选择是直接获取整个数据集,而无需使用数据加载器,就像这样:

images, labels = dataset[:]

答案 2 :(得分:0)

如果数据集是火炬batch_size=dataset.__len__(),则可以设置Dataset,否则应该使用batch_szie=len(dataset)之类的东西。

当心,这可能需要大量内存,具体取决于您的数据集。