如果事先不知道训练样本的顺序和总数,如何创建自定义PyTorch数据集?

时间:2019-02-07 10:36:06

标签: deep-learning pytorch

我有一个42 GB的jsonl文件。该文件的每个元素都是一个json对象。我从每个json对象创建训练样本。但是我提取的每个json对象的训练样本数量可以在0到5个样本之间变化。在不读取内存中整个jsonl文件的情况下创建自定义PyTorch数据集的最佳方法是什么?

这是我正在谈论的数据集-Google Natural Questions

1 个答案:

答案 0 :(得分:5)

您有两种选择。

  1. 如果没有很多小文件,最简单的选择是将每个json对象预处理为单个文件。然后,您可以根据请求的索引阅读每个索引。例如
   
    class SingleFileDataset(Dataset):
        def __init__(self, list_of_file_paths):
            self.list_of_file_paths = list_of_file_paths

        def __getitem__(self, index):
            return np.load(self.list_of_file_paths[index]) # Or equivalent reading code for single file
  1. 您还可以将数据拆分为一定数量的文件,然后在给定索引的情况下计算样本所驻留的文件。然后,您需要将该文件打开到内存中并读取适当的索引。这在磁盘访问和内存使用之间进行了权衡。假设您有n个样本,并且在预处理过程中我们将样本平均分成c个文件。现在,要读取索引为i的示例,我们将做
   
    class SplitIntoFilesDataset(Dataset):
        def __init__(self, list_of_file_paths, n_splits):
            self.list_of_file_paths = list_of_file_paths
            self.n_splits = n_splits

        def __getitem__(self, index):
            # index // n_splits is the relevant file, and 
            # index % len(self) is the index in in that file
            file_to_load = self.list_of_file_paths[index // self.n_splits]
            # Load file
            file = np.load(file)
            datapoint = file[index % len(self)]
  1. 最后,您可以使用一个HDF5文件,该文件允许访问磁盘上的行。如果您有大量数据,这可能是最好的解决方案,因为数据将在磁盘上关闭。我在下面复制了一个实现here

    import h5py
    import torch
    import torch.utils.data as data
    class H5Dataset(data.Dataset):
    
        def __init__(self, file_path):
            super(H5Dataset, self).__init__()
            h5_file = h5py.File(file_path)
            self.data = h5_file.get('data')
            self.target = h5_file.get('label')
    
        def __getitem__(self, index):            
            return (torch.from_numpy(self.data[index,:,:,:]).float(),
                    torch.from_numpy(self.target[index,:,:,:]).float())
    
        def __len__(self):
            return self.data.shape[0]