我生成了大约500
个分片的numpy数据文件,每个文件包含大约10000
个数据样本(例如,图片及其标签),例如:
file-000001.npy
file-000002.npy
file-000003.npy
...
file-000500.npy
每个.npy
都包含一个numpy字典,其键和大小为{'image':10000x3x512x64 (dtype=np.float32),'label':10000x100 (dtype=np.float32)}
。请注意,其中一些numpy文件包含少于10000
个样本,例如8111
等。
在训练期间,对于每个时期,我们需要迭代所有500x10000
个样本。由于容量限制,这些数据无法加载到内存中。常见的解决方案是数据预取队列。
我的想法如下:(1)首先记录每个文件中的所有文件名和数据样本的数量,(2)对于每个批次,计算批量索引,然后获取需要的相应数据文件加载到内存中以读取数据样本以组成批处理数据。
在步骤(2)中,如果我们将批量大小设置为256
,我们可能需要读取256
个文件并在每个文件中只读取一个样本来组成批量数据。这可能是缓慢而不实用的。
基于队列,数据加载可能在后台线程上运行,并且所有已读取的批处理数据都保存在队列中(容量可能很大,取决于内存容量)。并且后台线程在队列有空间后一致地读取数据以填充队列。
难以实现吗?我在谷歌搜索过,似乎有一些高级解决方案,例如使用cache
技术,使用mmap
。但我对这些家伙并不熟悉。这有什么简单的例子吗?