"IndexError: index out of range in self" when predicting all tokens with BERT

Time: 2020-09-30 12:19:13

Tags: python pytorch ocr bert-language-model

I am trying to predict all the masked tokens in a text with BERT, using the following function:


        def load_predict_BERT(self, masked_text):
            """
            Look for the [MASK] tokens and then attempts to predict the original value of the masked words

            :param masked_text: str
                Text containing [MASK] tokens for each word to predict
            :return: predictions:
            :return: MASKIDS: list
            """
            # Load the pre-trained tokenizer and tokenize the input (no training happens here)
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            tokenized_text = tokenizer.tokenize(masked_text)
            indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
            MASKIDS = [i for i, e in enumerate(tokenized_text) if e == '[MASK]']
            # Create the segments tensors
            segs = [i for i, e in enumerate(tokenized_text) if e == "."]
            segments_ids = []
            prev = -1
            for k, s in enumerate(segs):
                segments_ids = segments_ids + [k] * (s - prev)
                prev = s
            segments_ids = segments_ids + [len(segs)] * (len(tokenized_text) - len(segments_ids))
            segments_tensors = torch.tensor([segments_ids])
            # Prepare Torch inputs
            tokens_tensor = torch.tensor([indexed_tokens])
            # Load pre-trained model
            model = BertForMaskedLM.from_pretrained('bert-base-uncased')
            model.resize_token_embeddings(len(tokenizer))
            # Predict all tokens
            with torch.no_grad():
                predictions = model(tokens_tensor, segments_tensors)
            return predictions, MASKIDS, tokenizer

I get the error below and have not been able to fix it. I am a beginner with this approach and hope someone can help.

Traceback (most recent call last):
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-c1e97a1e3497>", line 3, in <module>
    text = doc.extract_text(lang='la_Lat')
  File "<ipython-input-2-669f80886eff>", line 79, in extract_text
    prediction, MASKIDS, tokenizer = self.load_predict_BERT(masked_text=masked_text)
  File "<ipython-input-2-669f80886eff>", line 170, in load_predict_BERT
    predictions = model(tokens_tensor, segments_tensors)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/pytorch_pretrained_bert/modeling.py", line 861, in forward
    sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/pytorch_pretrained_bert/modeling.py", line 730, in forward
    embedding_output = self.embeddings(input_ids, token_type_ids)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/pytorch_pretrained_bert/modeling.py", line 269, in forward
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 124, in forward
    return F.embedding(
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/functional.py", line 1814, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

I know something is wrong with the size of the tokens, but I don't know why.
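
The traceback ends inside `token_type_embeddings`, which suggests that the segment ids, rather than the token ids, are out of range. As far as I know, `bert-base-uncased` has a token-type table with only two rows (ids 0 and 1), while the loop in the function hands out a new id after every `.` token. The sketch below reproduces that scheme in plain Python, without loading BERT, and shows one possible way to keep the ids in range. `sentence_segment_ids`, `clamp_segment_ids`, and the sample token list are hypothetical names and data made up for illustration.

```python
def sentence_segment_ids(tokens):
    """Reproduce the question's scheme: the id grows by one after each '.' token."""
    segs = [i for i, tok in enumerate(tokens) if tok == "."]
    segment_ids = []
    prev = -1
    for k, s in enumerate(segs):
        segment_ids += [k] * (s - prev)
        prev = s
    # Tokens after the last '.' get one more id.
    segment_ids += [len(segs)] * (len(tokens) - len(segment_ids))
    return segment_ids


def clamp_segment_ids(segment_ids, type_vocab_size=2):
    """Cap every id at type_vocab_size - 1 (assumed to be 2 for bert-base-uncased)."""
    return [min(i, type_vocab_size - 1) for i in segment_ids]


# Hypothetical tokenized text with three sentence-ending periods.
tokens = ["a", "[MASK]", ".", "b", "c", ".", "d", ".", "e"]

ids = sentence_segment_ids(tokens)
print(ids)                     # [0, 0, 0, 1, 1, 1, 2, 2, 3]
print(max(ids) >= 2)           # True: ids 2 and 3 fall outside a two-row table

safe_ids = clamp_segment_ids(ids)
print(safe_ids)                # [0, 0, 0, 1, 1, 1, 1, 1, 1]
```

If sentence-pair information is not needed, another option that may work is to omit the second argument and call `model(tokens_tensor)`. The library should then fall back to all-zero segment ids.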

0 answers:

No answers yet.