我正在使用pytorch构建RNN网络。 数据存储在各种protobuf文件中。 protobuf中的每条记录代表一个带有多个时间戳的训练示例。
由于这是一个非常大的数据集,因此无法读取整个内存中的数据或通过扩展torch.utils.data.Dataset类来随机读取。
根据文档,建议使用torch.utils.data.IterableDataset。
在IterableDataset之上的DataLoader能够实现并行性
但是我无法在自定义数据上找到此实现,文档仅讨论简单的范围迭代器。
import math
import stream
from src import record_pb2
import torch
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, pb_file):
self.pb_file = pb_file
self.start = 0
self.end = 0
# One time read of the data to get the total count of records in the dataset
with stream.open(self.pb_file, 'rb') as data_stream:
for _ in data_stream:
self.end += 1
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # Single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
else:
# in a worker process, split the workload
per_worker = int(math.ceil((self.end - self.start))/float(worker_info.num_workers))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
data_stream = stream.open(self.pb_file, 'rb')
# Block to skip the streaming data till the iter start for the current worker process
i = 0
for _ in data_stream:
i += 1
if i >= iter_start:
break
return iter(self.pb_stream)
我期待一种可以在大型流数据(protobuf)之上设计并行数据馈送器的机制
答案 0 :(得分:0)
__iter__
的{{1}}方法将IterableDataset
一次采样一次数据。在并行设置中,您必须根据worker_id选择样本。对于使用此数据集的yield
,DataLoader
和shuffle
选项将不起作用,因为sampler
将没有任何索引。换句话说,让您的数据集一次生成一个样本,数据加载器将负责加载它们。这个答案吗?