BERT调用函数中的关键字参数

时间:2020-03-24 21:57:34

标签: tensorflow nlp arguments huggingface-transformers

在HuggingFace TensorFlow 2.0 BERT库中,documentation指出:

TF 2.0模型接受两种格式作为输入:

  • 将所有输入作为关键字参数(例如PyTorch模型),或

  • 在第一个位置将所有输入作为列表,元组或字典 论点。

我正在尝试使用这两个中的第一个调用我创建的BERT模型:

from transformers import BertTokenizer, TFBertModel
import tensorflow as tf

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = TFBertModel.from_pretrained('bert-base-uncased')

text = ['This is a sentence', 
        'The sky is blue and the grass is green', 
        'More words are here']
labels = [0, 1, 0]
tokenized_text = bert_tokenizer.batch_encode_plus(batch_text_or_text_pairs=text,
                                                  pad_to_max_length=True,
                                                  return_tensors='tf')
dataset = tf.data.Dataset.from_tensor_slices((tokenized_text['input_ids'],
                                              tokenized_text['attention_mask'],
                                              tokenized_text['token_type_ids'],
                                              tf.constant(labels))).batch(3)
sample = next(iter(dataset))

result1 = bert_model(inputs=(sample[0], sample[1], sample[2]))  # works fine
result2 = bert_model(inputs={'input_ids': sample[0], 
                             'attention_mask': sample[1], 
                             'token_type_ids': sample[2]})  # also fine
result3 = bert_model(input_ids=sample[0], 
                     attention_mask=sample[1], 
                     token_type_ids=sample[2])  # raises an error

但是当我执行最后一行时,出现错误:

TypeError: __call__() missing 1 required positional argument: 'inputs'

有人可以解释一下如何正确使用输入的关键字参数样式吗?

1 个答案:

答案 0 :(得分:1)

如果您不仅仅将一个张量用作第一个参数,他们似乎在内部将inputs解释为input_ids。您可以在TFBertModel中看到此内容,然后寻找TFBertMainLayer的{​​{1}}函数。

对于我来说,如果执行以下操作,我将得到与callresult1完全相同的结果:

result2

或者,您也可以放下result3 = bert_model(inputs=sample[0], attention_mask=sample[1], token_type_ids=sample[2]) ,效果也一样。

相关问题