Pytorch:具有多个损失的多个数据集

时间:2019-04-17 10:48:44

标签: pytorch

我正在使用多个数据集。我有多个损失,每个损失都必须在这些数据集的子集上进行评估。我想从每个数据集中生成一个批次,并评估其所有适当批次上的每个损失。一些损耗是成对的(需要加载相应数据点对),而其他损耗是在单个数据点上计算的。我需要以易于添加新数据集的方式进行设计。是否有任何内置的pytorch会对此有所帮助?在pytorch中设计此的最佳方法是什么?预先感谢。

1 个答案:

答案 0 :(得分:0)

不清楚您的设置到底是什么。
但是,您可以有多个Dataset实例,每个数据集一个。
在数据集之上,您可以实现“标​​记数据集”,即为所有样本添加“标签”的数据集:

class TaggedDataset(data.Dataset):
  def __init__(dataset, tag):
    super(TaggedDataset, self).__init__()
    self.ds_ = dataset
    self.tag_ = tag

  def __len__(self):
    return len(self.ds_)

  def __getitem__(self, index):
    return self.ds_[index], self.tag_

为每个数据集提供不同的tag,将所有数据集合并为一个ConcatDataset,然后将常规DataLoader包裹起来。

现在,在您的训练代码中

for input, label, tag in my_tagged_loader:
  # process each input according to the dataset tag it got.