如何在没有重叠批次的情况下迭代 pytorch 中的组合数据集?

时间:2021-07-13 04:39:29

标签: python pytorch pytorch-dataloader

我正在寻找一种将两个数据集连接到一个的方法,以便可以在一个循环中对其进行训练。然而,批次不允许在数据集之间混合。在以下示例中,批次应仅在 1 到 10 和 41 到 50 范围内:

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset

df1 = pd.DataFrame(list(range(1,11)))
df2 = pd.DataFrame(list(range(41,51)))

class testset(Dataset):
    def __init__(self,data):
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[0][index]

testdataset1 = testset(df1)
testdataset2 = testset(df2)

datasets = []
datasets.append(testdataset1)
datasets.append(testdataset2)

concat_dataset = ConcatDataset(datasets)

loader = DataLoader(
    concat_dataset,
    shuffle=False,
    num_workers=0,
    batch_size=3
)

for data in loader:
    print(data)

张量([1, 2, 3])

张量([4, 5, 6])

张量([7, 8, 9])

tensor([10, 41, 42]) ← 那不应该存在

张量([43, 44, 45])

张量([46, 47, 48])

张量([49, 50])

在实际情况中,我结合了两个时间序列,其中两个数据集的值成批重叠会导致一些麻烦......

这不应该是一个想法,对吧?

1 个答案:

答案 0 :(得分:0)

如果你这样做了,你就不会再创建随机批次(这些是伪随机的),因为批次元素受到限制(如果第一个元素来自 0 数据集,其余的也必须)。​​

简短说明:

  • batch_size 必须指定(因为样本生成依赖于它)
  • 可选的 length 参数,因为现在这个数据集可以是任意长度(样本是通过 modulo 操作从某个数据集中获取的)
  • 从第 0 个数据集开始并从中生成批处理
  • 移动到另一个数据集(您可以在 __getitem__ 方法内切换):
    • 随机:方法_new_random_dataset
    • 简单的下一个:方法_next_dataset

下面是一个 torch.utils.data.Dataset 自定义实例,它可以满足您的需求:

class Merger(torch.utils.data.Dataset):
    def __init__(
        self, *datasets: torch.utils.data.Dataset, batch_size: int, length: int = None
    ):
        self.datasets = datasets
        self.batch_size = batch_size

        if length is None:
            self._len = sum(len(d) for d in self.datasets)
        else:
            self._len = length

        # Keep in internal var how many items we've generated
        # Only possible dataset switch when new batch is created
        self._items_generated = 0
        # First batch will always go from the 0th dataset
        self._dataset_index = 0

    def __len__(self):
        return self._len

    def _next_dataset(self):
        if self._dataset_index == len(self.datasets) - 1:
            self._dataset_index = 0
        else:
            self._dataset_index += 1

    def _new_random_dataset(self):
        self._dataset_index = random.randrange(0, len(self.datasets))

    def __getitem__(self, index):
        if self._items_generated >= self.batch_size:
            self._items_generated = 0
            # self._next_dataset()
            self._new_random_dataset()

        self._items_generated += 1
        return self.datasets[self._dataset_index][
            index % len(self.datasets[self._dataset_index])
        ]

和示例用法供您验证:

df1 = pd.DataFrame(list(range(1, 11)))
df2 = pd.DataFrame(list(range(41, 51)))

ds = Merger(testset(df1), testset(df2), batch_size=3)

loader = torch.utils.data.DataLoader(ds, shuffle=False, num_workers=0, batch_size=3)

for data in loader:
    print(data)