在下一句预测(Next Sentence Prediction)任务上微调 BERT

时间:2020-09-13 12:05:37

标签: python bert-language-model

我正在尝试使用 Huggingface 库,在下一句预测(Next Sentence Prediction)任务上对 BERT 进行微调。我参考了教程,并尝试使用 DataCollatorForNextSentencePrediction 和 TextDatasetForNextSentencePrediction。但在使用它们时出现了以下错误。报错信息之后附上了我的代码。

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-7678758b2c9c> in <module>()
     56   train(bert_model,bert_tokenizer,train_data_set_path)
     57   #prepare_data_set(bert_tokenizer)
---> 58 main()

9 frames
<ipython-input-18-7678758b2c9c> in main()
     54   bert_model = BertForNextSentencePrediction.from_pretrained("bert-base-cased")
     55   train_data_set_path = "/content/drive/My Drive/next_sentence/line_data_set_file.txt"
---> 56   train(bert_model,bert_tokenizer,train_data_set_path)
     57   #prepare_data_set(bert_tokenizer)
     58 main()

<ipython-input-18-7678758b2c9c> in train(bert_model, bert_tokenizer, path, eval_path)
     47 
     48   )
---> 49   trainer.train()
     50   trainer.save_model(out_dir)
     51 def main():

/usr/local/lib/python3.6/dist-packages/transformers/trainer.py in train(self, model_path, trial)
    697 
    698             epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
--> 699             for step, inputs in enumerate(epoch_iterator):
    700 
    701                 # Skip past any already trained steps if resuming training

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    361 
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    401     def _next_data(self):
    402         index = self._next_index()  # may raise StopIteration
--> 403         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    404         if self._pin_memory:
    405             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     45         else:
     46             data = self.dataset[possibly_batched_index]
---> 47         return self.collate_fn(data)

/usr/local/lib/python3.6/dist-packages/transformers/data/data_collator.py in __call__(self, examples)
    356         for i, doc in enumerate(examples):
    357             input_id, segment_id, attention_mask, label = self.create_examples_from_document(doc, i, examples)
--> 358             input_ids.extend(input_id)
    359             segment_ids.extend(segment_id)
    360             attention_masks.extend(attention_mask)

/usr/local/lib/python3.6/dist-packages/transformers/data/data_collator.py in create_examples_from_document(self, document, doc_index, examples)
    444                         random_document = examples[random_document_index]
    445                         random_start = random.randint(0, len(random_document) - 1)
--> 446                         for j in range(random_start, len(random_document)):
    447                             tokens_b.extend(random_document[j])
    448                             if len(tokens_b) >= target_b_length:

/usr/lib/python3.6/random.py in randint(self, a, b)
    219         """
    220 
--> 221         return self.randrange(a, b+1)
    222 
    223     def _randbelow(self, n, int=int, maxsize=1<<BPF, type=type,

/usr/lib/python3.6/random.py in randrange(self, start, stop, step, _int)
    197             return istart + self._randbelow(width)
    198         if step == 1:
--> 199             raise ValueError("empty range for randrange() (%d,%d, %d)" % (istart, istop, width))
    200 
    201         # Non-unit step argument supplied.

ValueError: empty range for randrange() (0,0, 0) 
    def train(bert_model, bert_tokenizer, path, eval_path=None):
        """Fine-tune ``bert_model`` on next-sentence prediction and save it.

        Builds a ``TextDatasetForNextSentencePrediction`` from ``path``, trains
        for one epoch with ``Trainer`` and writes the result to the output
        directory on Google Drive.

        Args:
            bert_model: Model handed to ``Trainer``.
            bert_tokenizer: Tokenizer shared by the dataset and the collator.
            path: Text file with the training data (presumably one sentence
                per line with blank lines between documents -- confirm against
                the installed transformers version).
            eval_path: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If the dataset contains no non-empty document.
        """
        out_dir = "/content/drive/My Drive/next_sentence/"
        training_args = TrainingArguments(
            output_dir=out_dir,
            overwrite_output_dir=True,
            num_train_epochs=1,
            per_device_train_batch_size=30,
            save_steps=10000,
            save_total_limit=2,
        )

        data_collator = DataCollatorForNextSentencePrediction(
            tokenizer=bert_tokenizer,
            mlm=False,
            block_size=512,
            nsp_probability=0.5,
        )

        dataset = TextDatasetForNextSentencePrediction(
            tokenizer=bert_tokenizer,
            file_path=path,
            block_size=512,
        )

        # The collator samples a random sentence from a random document with
        # random.randint(0, len(random_document) - 1). A zero-length document
        # makes that range empty and raises
        # "ValueError: empty range for randrange() (0,0, 0)".
        # Drop such documents up front.
        # NOTE(review): assumes the dataset keeps its documents in a list
        # attribute named ``examples`` -- confirm for the installed version.
        documents = getattr(dataset, "examples", None)
        if documents is not None:
            dataset.examples = [doc for doc in documents if doc]

        if len(dataset) == 0:
            raise ValueError(
                "No non-empty documents found in training file: %s" % path
            )

        trainer = Trainer(
            model=bert_model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
        )
        trainer.train()
        trainer.save_model(out_dir)
    def main():
        """Load the pretrained tokenizer and model, then start fine-tuning."""
        print("Running main")
        checkpoint = "bert-base-cased"
        data_file = "/content/drive/My Drive/next_sentence/line_data_set_file.txt"
        tokenizer = BertTokenizer.from_pretrained(checkpoint)
        model = BertForNextSentencePrediction.from_pretrained(checkpoint)
        # prepare_data_set(...) is not invoked here.
        train(model, tokenizer, data_file)
    main()

0 个答案:

没有答案