分析Bert输出以进行序列分类

时间:2020-09-09 17:19:27

标签: text-classification bert-language-model

有人可以帮助我了解用于序列分类的最后一个隐藏层中BERT的输出吗?我正在对一个用于序列分类的界面模型进行测试,并按如下方式访问输出:

sequence = "this coffee tastes bad"
tok_sequence = tokenizer(sequence)

output = model(**tok_sequence)       # model inference
output[1][0]                         # accessed by setting config.output_hidden_states = True
output[1][0].shape                   # --> (1, 7, 768) --> | CLS | this | coffee | taste | ##s | bad | SEP |

在字符串“这杯咖啡味道不好”中,我基本上是获取最后一个隐藏层的输出,在这种情况下,该隐藏层的形状为(1,7,768),[CLS] + 5个单词标记+ [SEP] ,然后遍历每个令牌,求和它们的值(768)并计算平均值。得出的总数将在下面的图表图像中输出。

this

Output of BERT model

任何帮助口译都将有所帮助。我要寻找的方向是提取最有意义的单词或对输出情感影响最大的单词。在查看输出时(尤其是第二个示例),似乎正值不是那么重要的单词,而负值是有影响的单词。对于大多数输出​​示例,这似乎都是正确的。但是,我可能对此并不了解。

任何其他建议的方法也将有助于文本分类。

0 个答案:

没有答案