我正在寻找一种将两个数据集连接到一个的方法,以便可以在一个循环中对其进行训练。然而,批次不允许在数据集之间混合。在以下示例中,批次应仅在 1 到 10 和 41 到 50 范围内:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
df1 = pd.DataFrame(list(range(1,11)))
df2 = pd.DataFrame(list(range(41,51)))
class testset(Dataset):
def __init__(self,data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[0][index]
testdataset1 = testset(df1)
testdataset2 = testset(df2)
datasets = []
datasets.append(testdataset1)
datasets.append(testdataset2)
concat_dataset = ConcatDataset(datasets)
loader = DataLoader(
concat_dataset,
shuffle=False,
num_workers=0,
batch_size=3
)
for data in loader:
print(data)
张量([1, 2, 3])
张量([4, 5, 6])
张量([7, 8, 9])
tensor([10, 41, 42]) ← 那不应该存在
张量([43, 44, 45])
张量([46, 47, 48])
张量([49, 50])
在实际情况中,我结合了两个时间序列,其中两个数据集的值成批重叠会导致一些麻烦......
这不应该是一个想法,对吧?
答案 0 :(得分:0)
如果你这样做了,你就不会再创建随机批次(这些是伪随机的),因为批次元素受到限制(如果第一个元素来自 0
数据集,其余的也必须)。
简短说明:
batch_size
必须指定(因为样本生成依赖于它)length
参数,因为现在这个数据集可以是任意长度(样本是通过 modulo
操作从某个数据集中获取的)0
个数据集开始并从中生成批处理__getitem__
方法内切换):
_new_random_dataset
_next_dataset
下面是一个 torch.utils.data.Dataset
自定义实例,它可以满足您的需求:
class Merger(torch.utils.data.Dataset):
def __init__(
self, *datasets: torch.utils.data.Dataset, batch_size: int, length: int = None
):
self.datasets = datasets
self.batch_size = batch_size
if length is None:
self._len = sum(len(d) for d in self.datasets)
else:
self._len = length
# Keep in internal var how many items we've generated
# Only possible dataset switch when new batch is created
self._items_generated = 0
# First batch will always go from the 0th dataset
self._dataset_index = 0
def __len__(self):
return self._len
def _next_dataset(self):
if self._dataset_index == len(self.datasets) - 1:
self._dataset_index = 0
else:
self._dataset_index += 1
def _new_random_dataset(self):
self._dataset_index = random.randrange(0, len(self.datasets))
def __getitem__(self, index):
if self._items_generated >= self.batch_size:
self._items_generated = 0
# self._next_dataset()
self._new_random_dataset()
self._items_generated += 1
return self.datasets[self._dataset_index][
index % len(self.datasets[self._dataset_index])
]
和示例用法供您验证:
df1 = pd.DataFrame(list(range(1, 11)))
df2 = pd.DataFrame(list(range(41, 51)))
ds = Merger(testset(df1), testset(df2), batch_size=3)
loader = torch.utils.data.DataLoader(ds, shuffle=False, num_workers=0, batch_size=3)
for data in loader:
print(data)