我有一个不适合内存(150G)的巨大数据集,我正在寻找在pytorch中使用它的最佳方法。数据集由数个.npz
文件组成,每个文件有10k个样本。我试图建立一个Dataset
类
class MyDataset(Dataset):
def __init__(self, path):
self.path = path
self.files = os.listdir(self.path)
self.file_length = {}
for f in self.files:
# Load file in as a nmap
d = np.load(os.path.join(self.path, f), mmap_mode='r')
self.file_length[f] = len(d['y'])
def __len__(self):
raise NotImplementedException()
def __getitem__(self, idx):
# Find the file where idx belongs to
count = 0
f_key = ''
local_idx = 0
for k in self.file_length:
if count < idx < count + self.file_length[k]:
f_key = k
local_idx = idx - count
break
else:
count += self.file_length[k]
# Open file as numpy.memmap
d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
# Actually fetch the data
X = np.expand_dims(d['X'][local_idx], axis=1)
y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
return X, y
但是当实际获取样本时,它花费了30多秒的时间。似乎整个.npz
已打开,已存储在RAM中,并且已访问了正确的索引。
如何提高效率?
编辑
这似乎对.npz
文件see post有误解,但是有更好的方法吗?
解决方案
正如@covariantmonkey所建议的那样,lmdb是一个不错的选择。目前,由于问题来自.npz
文件而不是memmap
,我通过将.npz
软件包文件拆分为几个.npy
文件来对数据集进行了建模。现在,我可以使用相同的逻辑,其中memmap
有意义,而且速度非常快(加载样本需要几毫秒)。
答案 0 :(得分:0)
单个.npz
文件有多大?一个月前,我处于类似的困境。各种forum帖子,后来我通过Google搜索lmdb路线。这是我所做的
key = filename
和data = np.savez_compressed(stff)
创建lmdb索引 lmdb
会为您处理mmap,并且非常快地加载。
致谢,
A
PS:savez_compessed
需要一个字节对象,因此您可以执行类似
output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
#cache output in lmdb