如何从TFHub下载的预训练word2vec模型中获取单词向量?

时间:2020-06-10 18:01:20

标签: tensorflow deep-learning nlp word2vec

所以我正在使用TFHub的以下word2vec模型:

embed = hub.load("https://tfhub.dev/google/Wiki-words-250-with-normalization/2")

该对象的类型为:

tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject

虽然我可以使用该模型嵌入文本列表,但是我不清楚如何访问嵌入词本身。

1 个答案:

答案 0 :(得分:0)

首先,让我们讨论一下embed到底是什么?根据{{​​3}},embed对象是基于以TensorFlow 2格式存储的Skipgram模型创建的TextEmbedding。

Skipgram模型只是一个前馈神经网络,它将词汇表中单词的单次编码表示形式作为输入,并计算词嵌入。因此,这些词嵌入未存储在模型中,而是经过计算的。

因此,如果您希望将单词嵌入单独的单词,则可以一次将它们传递一次,如下所示:

# word embedding of `apple`
>>> apple_embedding = embed(["apple"])
>>> apple_embedding.shape
TensorShape([1, 250])

>>> #concatenation of three different word embeddings
>>> group = embed(["apple", "banana", "carrot"])
>>> group.shape
TensorShape([3, 250])