转换器库中BertModel中hidden_​​states元组的内容如何排列

时间:2020-04-15 11:49:22

标签: pytorch python-3.7 huggingface-transformers

model = BertModel.from_pretrained('bert-base-uncased', config=BertConfig.from_pretrained('bert-base-uncased',output_hidden_states=True))
outputs = model(input_ids) 
hidden_states = outputs[2]

hidden_​​states是一个13 torch.FloatTensors的元组。每个张量的大小为:(batch_size, sequence_length, hidden_size)。 根据文档,13张量是嵌入和12编码器层的隐藏状态。

我的问题:

hidden_states[0]嵌入层,而hidden_states[12]是第12编码层或

hidden_states[0]嵌入层,而hidden_states[12]是第一编码层还是

hidden_states[0]的第12编码层,而hidden_states[12]是嵌入层或

hidden_states[0]的第一编码层,而hidden_states[12]是嵌入层

我没有发现其他地方对此有明确的说明。

1 个答案:

答案 0 :(得分:0)

看看BertModel的source-code,可以得出以下结论:hidden_​​states [0]包含初始嵌入层的输出,而元组中的其余元素包含的隐藏状态按每个递增的顺序层。简而言之,hidden_​​states [1]包含BERT第一层的输出,hidden_​​states [12]包含最后一层,即第12层。