如何自定义pytorch中的数据加载器以处理大型JSON文件中的数据?

时间:2019-03-11 20:13:32

标签: deep-learning time-series pytorch

我正在处理时间序列问题。不同的训练时间序列数据存储在大小为30GB的大型JSON文件中。在张量流中,我知道如何使用TF记录。 pytorch中有类似的方法吗?

2 个答案:

答案 0 :(得分:4)

我想您需要IterableDatasetdocs),因为:

  1. 您可能想遍历没有随机访问权限的文件;
  2. json中的样本数未预先计算。

我已经做了一个最小的用法示例,假设数据集文件的每一行都是一个json本身,但是您可以更改逻辑。

import json
from torch.utils.data import DataLoader, IterableDataset


class JsonDataset(IterableDataset):
    def __init__(self, files):
        self.files = files

    def __iter__(self):
        for json_file in self.files:
            with open(json_file) as f:
                for sample_line in f:
                    sample = json.loads(sample_line)
                    yield sample['x'], sample['time'], ...

...

dataset = JsonDataset(['data/1.json', 'data/2.json', ...])
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
    y = model(batch)

答案 1 :(得分:1)

通常,您不需要更改/重载默认的data.Dataloader

您应该研究的是如何创建custom data.Dataset
一旦拥有了自己的Dataset,并且知道如何从json文件中逐项提取内容,您就可以将其喂入“香草” data.Dataloader并完成所有批处理/多重处理等操作为您提供的数据集为基础。

例如,如果您有一个包含几个json文件的文件夹,每个文件包含几个示例,则可以有一个Dataset,看起来像:

import bisect

class MyJsonsDataset(data.Dataset):
  def __init__(self, jfolder):
    super(MyJsonsDataset, self).__init__()
      self.filenames = []  # keep track of the jfiles you need to load
      self.cumulative_sizes = [0]  # keep track of number of examples viewed so far
      # this is not actually python code - just pseudo code for you to follow
      for each jsonfile in jfolder:
        self.filenames.append(jsonfile)
        l = number of examples in jsonfile
        self.cumulative_sizes.append(self.cumulative_sizes[-1] + l)
      # discard the first element 
      self.cumulative_sizes.pop(0)

  def __len__(self):
    return self.cumulative_sizes[-1]

  def __getitem__(self, idx):
    # first you need to know wich of the files holds the idx example
    jfile_idx = bisect.bisect_right(self.cumulative_sizes, idx)
    if jfile_idx == 0:
      sample_idx = idx
    else:
      sample_idx = idx - self.cumulative_sizes[jfile_idx - 1]
    # now you need to retrieve the `sample_idx` example from self.filenames[jfile_idx]
    return retrieved_example