使用pytorch的torch.utils.data.IterableDataset在protobuf文件顶部的数据加载器

时间:2019-08-28 04:39:23

标签: protocol-buffers pytorch recurrent-neural-network

我正在使用pytorch构建RNN网络。 数据存储在各种protobuf文件中。 protobuf中的每条记录代表一个带有多个时间戳的训练示例。

由于这是一个非常大的数据集,因此无法读取整个内存中的数据或通过扩展torch.utils.data.Dataset类来随机读取。

根据文档,建议使用torch.utils.data.IterableDataset。

在IterableDataset之上的

DataLoader能够实现并行性

但是我无法在自定义数据上找到此实现,文档仅讨论简单的范围迭代器。

import math

import stream
from src import record_pb2
import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, pb_file):
        self.pb_file = pb_file
        self.start = 0
        self.end = 0
        # One time read of the data to get the total count of records in the dataset
        with stream.open(self.pb_file, 'rb') as data_stream:
            for _ in data_stream:
                self.end += 1

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # Single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:
            # in a worker process, split the workload
            per_worker = int(math.ceil((self.end - self.start))/float(worker_info.num_workers))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        data_stream = stream.open(self.pb_file, 'rb')

        # Block to skip the streaming data till the iter start for the current worker process
        i = 0
        for _ in data_stream:
            i += 1
            if i >= iter_start:
                break

        return iter(self.pb_stream)

我期待一种可以在大型流数据(protobuf)之上设计并行数据馈送器的机制

1 个答案:

答案 0 :(得分:0)

__iter__的{​​{1}}方法将IterableDataset一次采样一次数据。在并行设置中,您必须根据worker_id选择样本。对于使用此数据集的yieldDataLoadershuffle选项将不起作用,因为sampler将没有任何索引。换句话说,让您的数据集一次生成一个样本,数据加载器将负责加载它们。这个答案吗?