从3D张量中提取特定令牌的嵌入

时间:2019-12-17 07:44:47

标签: tensorflow keras

每个训练示例都包含一个句子(标记列表),一个起始标记索引和一个结束标记索引。开始和结束索引突出显示了句子中的某些单词。标签为True或False,表示突出显示的单词是否“有趣”。
例如,在此示例中,“新闻”是突出显示的单词。

tokens=["This", "news", "is", "sad"], start=1, end=2, label=1 

这是我的模型设置。 句子被馈送到预先训练的bert层,该层返回一个3D张量,其中包含在每个句子中嵌入每个标记的

我想从3D张量中提取每个方言词的高亮词的嵌入,并获取它们的均值,以便可以将其馈送到密集层中,我该怎么做?假设我已经有了3D张量。

我正在使用由tensorflow 2.0提供的keras API。 我是tensorflow的新手,所以非常感谢一个具体的例子。

谢谢!

0 个答案:

没有答案