How to adjust batch data according to the number of labels in PyTorch

Posted: 2020-07-16 07:28:49

Tags: pytorch

I built (n-gram, doc_id) pairs for document classification as follows:

def create_dataset(tok_docs, vocab, n):
  n_grams = []
  document_ids = []
  for i, doc in enumerate(tok_docs):
    # slide a window of length n over the document's tokens;
    # j indexes positions, while i is the document id used as the label
    for n_gram in [doc[0][j:j+n] for j in range(len(doc[0]) - n + 1)]:
       n_grams.append(n_gram)
       document_ids.append(i)
  return n_grams, document_ids

import torch
from torch.utils.data import TensorDataset, DataLoader

def create_pytorch_datasets(n_grams, doc_ids):
  n_grams_tensor = torch.tensor(n_grams)
  doc_ids_tensor = torch.tensor(doc_ids)
  full_dataset = TensorDataset(n_grams_tensor, doc_ids_tensor)
  return full_dataset

create_dataset returns a pair of (n_grams, document_ids), which looks like this:

n_grams, doc_ids = create_dataset( ... )
train_data = create_pytorch_datasets(n_grams, doc_ids)
>>> train_data[0:100]
(tensor([[2076, 517, 54, 3647, 1182, 7086],
         [517, 54, 3647, 1182, 7086, 1149],
         ...
         ]),
 tensor([0, 0, 0, 0, 0, ..., 3, 3, 3]))

train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True)

The first tensor holds the n-grams and the second tensor holds the doc_ids.
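
Each batch drawn from train_loader then unpacks into those two tensors; a minimal check of the shapes (just for illustration, not part of the training code):

for n_gram_batch, doc_id_batch in train_loader:
    # n_gram_batch: LongTensor of shape (batch_size, n)
    # doc_id_batch: LongTensor of shape (batch_size,)
    print(n_gram_batch.shape, doc_id_batch.shape)
    break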

However, as you can see, the amount of training data per label varies with the length of the document.

If a document is long, many training pairs carry that document's label.

I think this could make the model overfit, because the classifier would be biased toward predicting the long documents.

So I would like to draw input batches so that the labels (doc_ids) are uniformly distributed. How can I fix this in the code above?

P.S.) Given train_data like the following, I would like to sample batches with these probabilities:

  n-grams        doc_ids     sampling probability
([1, 2, 3, 4],      1)       ====> 0.33
([1, 3, 5, 7],      2)       ====> 0.33
([2, 3, 4, 5],      3)       ====> 0.33 * 0.25
([3, 5, 2, 5],      3)       ====> 0.33 * 0.25
([6, 3, 4, 5],      3)       ====> 0.33 * 0.25
([2, 3, 1, 5],      3)       ====> 0.33 * 0.25 
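
To make the target distribution concrete, here is a small sketch (my own illustration, not part of the pipeline above) that reproduces these numbers from doc_ids: sample uniformly over documents, then uniformly over the n-grams of each document.

from collections import Counter

doc_ids = [1, 2, 3, 3, 3, 3]       # doc_ids from the example above
counts = Counter(doc_ids)          # {1: 1, 2: 1, 3: 4}
num_docs = len(counts)             # 3

# per-example probability = (1 / num_docs) * (1 / number of n-grams in that doc)
probs = [1 / num_docs / counts[d] for d in doc_ids]
# -> [0.333, 0.333, 0.0833, 0.0833, 0.0833, 0.0833]   (0.0833 == 0.33 * 0.25)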

1 Answer:

Answer 0 (score: 1)

In PyTorch you can specify a sampler or a batch_sampler for the data loader to change how data points are sampled.

Documentation on the data loader: https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler

Documentation on samplers: https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler

For example, you can use WeightedRandomSampler to assign a weight to each data point. The weight could be, for instance, the inverse length of the document.
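
As a standalone illustration (the weights below are made up), WeightedRandomSampler simply draws indices with probability proportional to the given weights; the weights do not need to sum to 1, and num_samples controls how many indices are drawn per pass:

from torch.utils.data import WeightedRandomSampler

# three data points; index 0 is twice as likely to be drawn as the others
weights = [2.0, 1.0, 1.0]
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
print(list(sampler))  # e.g. [0, 0, 2, 1, 0, 0, 1, 2, 0, 1]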

I would modify the code as follows:

def create_dataset(tok_docs, vocab, n):
  n_grams = []
  document_ids = []
  weights = []  # << list of weights for sampling
  for i, doc in enumerate(tok_docs):
    for n_gram in [doc[0][j:j+n] for j in range(len(doc[0]) - n + 1)]:
       n_grams.append(n_gram)
       document_ids.append(i)
       weights.append(1 / len(doc[0]))  # << n-grams of long documents are sampled less often
  return n_grams, document_ids, weights

from torch.utils.data import WeightedRandomSampler

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)  # << create the sampler; draw one epoch's worth of samples

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, sampler=sampler)  # << the sampler replaces shuffling in the dataloader
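
A quick sanity check (assuming the modified create_dataset and the loader above): iterate over the loader and count how often each doc_id appears; with the inverse-length weights the counts should come out roughly balanced across documents.

from collections import Counter

label_counts = Counter()
for _, doc_id_batch in train_loader:
    label_counts.update(doc_id_batch.tolist())
print(label_counts)  # counts per doc_id should be roughly uniform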