pytorch二进制分类中如何处理不平衡类

时间:2020-09-28 16:12:20

标签: python pytorch bert-language-model huggingface-transformers

我正在研究二进制文本分类问题。我该如何应用smote或WeightedRandomSample来解决数据集中的不平衡问题。我的代码目前看起来像这样:

class GDataset(Dataset):

  def __init__(self, passage, targets, tokenizer, max_len):
    self.passage = passage
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len
  
  def __len__(self):
    return len(self.passage)
  
  def __getitem__(self, item):
    passage = str(self.passage[item])
    target = self.targets[item]
    
    if (target == 1) and self.transform: # minority class
            x = self.transform(x)

    encoding = self.tokenizer.encode_plus(
      passage,
      add_special_tokens=True,
      max_length=self.max_len,
      return_token_type_ids=False,
      pad_to_max_length=True,
      return_attention_mask=True,
      return_tensors='pt',
    )
    return
      'passage_text': passage,
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long

我如何使用其他平衡技术?

0 个答案:

没有答案