Huggingface的Bert的第二个输出是什么意思?

时间:2020-02-15 21:01:25

标签: python deep-learning pytorch huggingface-transformers

通过在拥抱面实现中使用基本BERT模型的原始配置,我得到了长度为2的元组。

import torch

import transformers
from transformers import AutoModel,AutoTokenizer

bert_name="bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_name)
BERT = AutoModel.from_pretrained(bert_name)

e=tokenizer.encode('I am hoping for the best', add_special_tokens=True)

q=BERT(torch.tensor([e]))

print (len(q)) #Output: 2

第一个元素是我希望收到的-每个输入令牌的768维嵌入。

print (e) #Output : [101, 1045, 2572, 5327, 2005, 1996, 2190, 102] 
print (q[0].shape) #Output : torch.Size([1, 8, 768])

但是元组中的第二个元素是什么?

print (q[1].shape) # torch.Size([1, 768])

其大小与每个令牌的编码相同。 那是什么?

也许是[CLS]令牌的副本,它表示整个编码文本的分类?

我们检查一下。

a= q[0][:,0,:]
b=q[1]

print (torch.eq(a,b)) #Output : Tensor([[False, False, False, .... False]])

不!

关于嵌入最后一个令牌的副本(无论出于何种原因)呢?

c= q[0][:,-1,:]
b=q[1]

print (torch.eq(a,c)) #Output : Tensor([[False, False, False, .... False]])

所以,也不是。

文档讨论了如何更改config会导致更多的tuple元素(如隐藏状态),但是我没有发现默认配置输出的此“神秘” tuple元素的任何描述。

关于它的含义和用途的任何想法?

1 个答案:

答案 0 :(得分:3)

在这种情况下,输出是(last_hidden_statepooler_output)的元组。您可以找到有关回报可能是什么 here的文档。