I've recently been experimenting with stacking language models and noticed something interesting: BERT's and XLNet's output embeddings are not the same as their input embeddings. For example, the following snippet:
import torch
import transformers

bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")
# Embed the token ids, then project them straight back onto the vocabulary.
sent = torch.tensor(tok.encode("I went to the store the other day, it was very rewarding."))
enc = bert.get_input_embeddings()(sent)
dec = bert.get_output_embeddings()(enc)
print(tok.decode(dec.softmax(-1).argmax(-1)))
outputs this for me:
,,,,,,,,,,,,,,,,,
I was under the impression that the input and output token embeddings are tied, so I expected to get back the (formatted) input sequence.
Interestingly, most other models do not exhibit this behavior. For example, if you run the same snippet with GPT2, Albert, or Roberta, it outputs the input sequence.
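One quick way to sanity-check the tying assumption (just a sketch, relying on the fact that both embedding modules expose a .weight tensor, which holds for the models above):

import torch
import transformers

bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
inp = bert.get_input_embeddings().weight    # [vocab_size, hidden_size] embedding matrix
out = bert.get_output_embeddings().weight   # [vocab_size, hidden_size] decoder matrix
print(torch.equal(inp, out))                # True if the two matrices hold identical values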
Is this a bug, or is it expected behavior for BERT / XLNet?
Answer 0: (Score: 3)
Not sure if it's too late, but I experimented with your code a bit, and it can be reverted. :)
import torch
import transformers

bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")
sent = torch.tensor(tok.encode("I went to the store the other day, it was very rewarding."))
print("Initial sentence:", sent)
enc = bert.get_input_embeddings()(sent)
dec = bert.get_output_embeddings()(enc)
# Take the softmax over dim 0 and the argmax over dim 1,
# instead of using the last dimension for both.
print("Decoded sentence:", tok.decode(dec.softmax(0).argmax(1)))
With this, you get the following output:
Initial sentence: tensor([ 101, 146, 1355, 1106, 1103, 2984, 1103, 1168, 1285, 117,
1122, 1108, 1304, 10703, 1158, 119, 102])
Decoded sentence: [CLS] I went to the store the other day, it was very rewarding. [SEP]
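If you want to try other sentences, you can fold the snippet above into a small helper; roundtrip here is just an illustrative name, not anything from transformers:

import torch
import transformers

bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")

def roundtrip(text):
    # Encode, embed, project back onto the vocabulary, and decode,
    # using the same softmax/argmax dimensions as in the snippet above.
    sent = torch.tensor(tok.encode(text))
    enc = bert.get_input_embeddings()(sent)
    dec = bert.get_output_embeddings()(enc)
    return tok.decode(dec.softmax(0).argmax(1))

print(roundtrip("I went to the store the other day, it was very rewarding."))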