如何在pytorch中使用大型数据集

时间:2019-02-18 18:59:25

标签: python numpy dataset pytorch large-data

我有一个不适合内存(150G)的巨大数据集,我正在寻找在pytorch中使用它的最佳方法。数据集由数个.npz文件组成,每个文件有10k个样本。我试图建立一个Dataset

class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = {}
        for f in self.files:
            # Load file in as a nmap
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        raise NotImplementedException()

    def __getitem__(self, idx):                
        # Find the file where idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count < idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open file as numpy.memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y

但是当实际获取样本时,它花费了30多秒的时间。似乎整个.npz已打开,已存储在RAM中,并且已访问了正确的索引。 如何提高效率?

编辑

这似乎对.npz文件see post有误解,但是有更好的方法吗?

解决方案

正如@covariantmonkey所建议的那样,lmdb是一个不错的选择。目前,由于问题来自.npz文件而不是memmap,我通过将.npz软件包文件拆分为几个.npy文件来对数据集进行了建模。现在,我可以使用相同的逻辑,其中memmap有意义,而且速度非常快(加载样本需要几毫秒)。

1 个答案:

答案 0 :(得分:0)

单个.npz文件有多大?一个月前,我处于类似的困境。各种forum帖子,后来我通过Google搜索lmdb路线。这是我所做的

  1. 将大型数据集分成足够小的文件,我可以放入gpu中-每个文件本质上都是我的微型批处理。在这个阶段,我没有为加载时间优化 just 内存。
  2. 使用key = filenamedata = np.savez_compressed(stff)创建lmdb索引

lmdb会为您处理mmap,并且非常快地加载。

致谢,
A

PS:savez_compessed需要一个字节对象,因此您可以执行类似

的操作
output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
#cache output in lmdb