我有两个数据加载器,我想在不重新定义数据集的情况下合并它们,在我的例子中是 train_dataset 和 val_dataset。
train_loader = DataLoader(train_dataset, batch_size = 512, drop_last=True,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size = 512, drop_last=False)
想要的结果:
train_loader = train_loader + val_loader
答案 0 :(得分:3)
数据加载器是迭代器,您可以实现一个函数,该函数返回一个迭代器,该迭代器产生数据加载器的内容,一个接一个的数据加载器。
给定多个迭代器 itrs
,它将迭代每个迭代器,然后迭代每个迭代器,一次产生一批。一种可能的实现方式很简单:
def itr_merge(*itrs):
for itr in itrs:
for v in itr:
yield v
这是一个用法示例:
>>> dl1 = DataLoader(TensorDataset(torch.zeros(5, 1)), batch_size=2, drop_last=True)
>>> dl2 = DataLoader(TensorDataset(torch.ones(10, 1)), batch_size=2)
>>> for x in itr_merge(dl1, dl2):
>>> print(x)
[tensor([[0.], [0.]])]
[tensor([[0.], [0.]])]
[tensor([[1.], [1.]])]
[tensor([[1.], [1.]])]
[tensor([[1.], [1.]])]
[tensor([[1.], [1.]])]
[tensor([[1.], [1.]])]