How to prepare a training dataset for a BERT Trainer in PyTorch?

Date: 2020-12-30 02:57:24

Tags: pytorch bert-language-model huggingface-transformers

The task is to use a pretrained BERT model for sequence classification to detect whether a blood label is present in a text sequence.

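import torch
from torch.utils.data import Dataset

# read_arff() and encode() are the asker's own helpers; they are not shown.


class BloodDataset(Dataset):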
    """MIMIC Blood dataset."""

    def __init__(self, arff_file):
        """
        Args:
            arff_file (string): Path to the arff file with annotations.
        """
        self.indices, self.contents, self.labels = read_arff(arff_file)
        self.labels = torch.as_tensor(self.labels)
        # encode() returns tokenizer-style output with token ids and attention masks
        self.inputs = encode(self.contents)
        self.input_ids = self.inputs['input_ids']
        self.attention_mask = self.inputs['attention_mask']

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # idx is matched against the ARFF sample indices, not used as a position 0..len-1
        if idx in self.indices:
            sample_index = self.indices.index(idx)
            sample = {'index': idx,
                      'content': self.contents[sample_index],
                      'label': self.labels[sample_index],
                      'input_ids': self.input_ids[sample_index],
                      'attention_mask': self.attention_mask[sample_index]
                      }
            return sample
        else:
            return "Sample not found!"
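For comparison, here is a minimal sketch (not a confirmed fix) of a dataset shaped the way the Trainer's DataLoader indexes it: the sampler requests positions 0 through len(dataset)-1, so every such position should return a dict of model inputs. `BloodTrainerDataset` is a hypothetical name; it assumes the tokenizer output has already been computed (as `encode()` does above) and drops the `index` and `content` fields, because depending on the transformers version the default collator may pass extra non-string keys such as `index` straight to `model.forward()`. It is written in plain Python so it runs without torch; subclassing `torch.utils.data.Dataset` or passing tokenizer tensors should behave the same way, since indexing a 2-D tensor by `idx` returns one row.

```python
from typing import Dict, List


class BloodTrainerDataset:
    """Map-style dataset indexed by position 0..len-1 (hypothetical sketch).

    `encodings` is assumed to be tokenizer output with one entry per sample
    under each key (e.g. "input_ids", "attention_mask"); `labels` holds one
    int class label per sample.
    """

    def __init__(self, encodings: Dict[str, List[List[int]]], labels: List[int]) -> None:
        if any(len(values) != len(labels) for values in encodings.values()):
            raise ValueError("each encoding field needs exactly one entry per label")
        self.encodings = encodings
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        # Positional lookup: an out-of-range idx raises IndexError
        # instead of returning a placeholder string.
        item = {key: values[idx] for key, values in self.encodings.items()}
        # "labels" is the keyword argument BertForSequenceClassification.forward() accepts.
        item["labels"] = self.labels[idx]
        return item


# Toy usage with arbitrary token ids.
encodings = {
    "input_ids": [[101, 2001, 102], [101, 2002, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
}
dataset = BloodTrainerDataset(encodings, [1, 0])
print(len(dataset))  # 2
print(dataset[1])    # {'input_ids': [101, 2002, 102], 'attention_mask': [1, 1, 1], 'labels': 0}
```

Returning only model inputs keeps the batch passed to `model(**batch)` free of keywords the model does not know about.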

The Hugging Face tutorial suggests a Trainer-based solution:

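from transformers import BertForSequenceClassification, Trainer, TrainingArguments


# model_type is not shown in the question; "bert-base-uncased" is an assumed default
def run(train_dataset, test_dataset, model_type="bert-base-uncased"):
    model = BertForSequenceClassification.from_pretrained(model_type)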
    training_args = TrainingArguments(
        output_dir='./results',          # output directory
        logging_dir='./logs',            # directory for storing logs
    )
    trainer = Trainer(
        # the instantiated 🤗 Transformers model to be trained
        model=model,
        args=training_args,
        train_dataset=train_dataset,         # training dataset
        eval_dataset=test_dataset            # evaluation dataset
    )
    return trainer

https://huggingface.co/transformers/training.html#trainer

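import preprocess  # the asker's module containing BloodDataset
import train       # the asker's module containing run()

train_dataset = preprocess.BloodDataset("test_blood.arff")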
trainer = train.run(train_dataset, train_dataset)
trainer.train()
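Before calling `trainer.train()`, a quick sanity check can confirm that every position the sampler will request, 0 through len(dataset)-1, returns a dict with the expected keys. `check_trainer_dataset` and the toy `ArffIdDataset` are hypothetical names for illustration; the toy assumes ARFF ids that do not start at 0, which is one way the lookup in `BloodDataset.__getitem__` could fall through to its `"Sample not found!"` string.

```python
from typing import Any, Sequence


def check_trainer_dataset(dataset: Any, required: Sequence[str] = ("input_ids", "attention_mask")) -> None:
    """Hypothetical helper: check every position the DataLoader will request."""
    for idx in range(len(dataset)):
        sample = dataset[idx]
        if not isinstance(sample, dict):
            raise TypeError(f"dataset[{idx}] returned {type(sample).__name__}, expected dict")
        missing = [key for key in required if key not in sample]
        if missing:
            raise KeyError(f"dataset[{idx}] is missing keys {missing}")


class ArffIdDataset:
    """Toy stand-in for BloodDataset whose samples carry ARFF ids 10 and 11."""

    def __init__(self) -> None:
        self.indices = [10, 11]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Any:
        # Same lookup style as BloodDataset: idx is matched against the ARFF ids.
        if idx in self.indices:
            return {"input_ids": [101, 102], "attention_mask": [1, 1]}
        return "Sample not found!"


try:
    check_trainer_dataset(ArffIdDataset())
except TypeError as err:
    print(err)  # dataset[0] returned str, expected dict
```

Because a `len()` of 2 makes the sampler ask for positions 0 and 1, neither of which is an ARFF id here, the check fails on the very first sample.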

This produces the following error:

/.local/lib/python3.6/site-packages/transformers/data/data_collator.py", line 38, in <listcomp>
    features = [vars(f) for f in features]
TypeError: vars() argument must have __dict__ attribute
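The traceback points at the collator's `vars(f)` call, which runs when a sample is not a dict. One likely cause, given the `__getitem__` above, is that a sample came back as the string `"Sample not found!"`: a `str` has no `__dict__`, so `vars()` raises exactly this TypeError. The function below is a simplified, hypothetical stand-in for the collation step, not the transformers source; it mirrors only the `vars()` branch from the traceback plus a `label` to `labels` rename, while the real collator does more (it builds tensors, for instance).

```python
from typing import Any, Dict, List


def collate_like_default(features: List[Any]) -> Dict[str, List[Any]]:
    """Simplified, hypothetical sketch of a default-style collation step."""
    if not isinstance(features[0], dict):
        # Objects without __dict__ (such as str) make vars() raise TypeError here.
        features = [vars(f) for f in features]
    batch: Dict[str, List[Any]] = {}
    for key in features[0]:
        # Rename "label" to "labels", the keyword the model's forward() takes.
        out_key = "labels" if key == "label" else key
        batch[out_key] = [f[key] for f in features]
    return batch


samples = [{"input_ids": [101, 102], "label": 1}, {"input_ids": [101, 103], "label": 0}]
print(collate_like_default(samples))
# {'input_ids': [[101, 102], [101, 103]], 'labels': [1, 0]}

try:
    collate_like_default(["Sample not found!"])
except TypeError as err:
    print(err)  # vars() argument must have __dict__ attribute
```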

What is the appropriate dataset input for the Trainer?
