如何将不适合内存的巨大数据集拆分和加载到pytorch Dataloader中?

时间:2019-09-07 18:17:44

标签: machine-learning deep-learning computer-vision pytorch

我正在训练一个深度学习模型,以便使用Google的Colab在NIH的Chest Xray-14数据集中对疾病进行多标签分类。大约有112k的训练示例和有限的RAM,我无法一次将所有图像加载到Dataloader中。

有没有一种方法可以将图像的路径仅存储在pytorch的DataLoader中,仅读取训练期间当前迭代所需的那些图像,并且一旦迭代完成,就会从内存中卸载图像,依此类推,直到一个时期完成

1 个答案:

答案 0 :(得分:0)

是的,ImageFolder的默认行为是创建图像路径列表并仅在需要时才加载实际图像。它不支持多类标签。但是,您可以编写自己的Dataset以支持多标签,有关详细信息,请参考ImageFolder类。

__init__期间,您将构建图像路径列表和相应的标签列表。仅应在调用__getitem__时加载图像。下面是此类数据集类的存根,详细信息取决于文件的组织,图像类型和标签格式。

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, args):
        """ Construct an indexed list of image paths and labels """

    def __getitem__(self, n):
        """ Load image n in the list of image paths and return it along with its label.
            In the case of multiclass the label will probably be a list of values"""

    def __len__(self):
        """ return the total number of images in this dataset """

创建了有效的数据集实例后,应创建DataLoader的实例,并提供数据集作为参数。 DataLoader负责采样其数据集,即调用您编写的__getitem__方法,并将单个样本放入迷你批中。它还处理并行加载并定义如何对索引进行采样。 DataLoader本身不会存储超出其需求的内容。任何时候它应该在内存中保留的最大样本数为batch_size * num_workers(如果为batch_size,则为num_workers == 0)。