有什么方法可以提取Google通用句子编码器的详尽词汇表吗?

时间:2019-03-13 18:13:50

标签: tensorflow tensorflow-hub

我要为其创建嵌入的某些句子,除非在句子中包含某些真正不寻常的单词,否则它对于相似性搜索非常有用。

在那种情况下,真正不寻常的词实际上包含了句子中任何词的最相似信息,但由于该词显然不在模型的词汇表中,因此所有信息在嵌入过程中都会丢失

我想获取GUSE嵌入模型已知的所有单词的列表,以便我可以将这些已知单词从句子中屏蔽掉,而仅保留“新颖”单词。

然后我可以对目标语料库中的那些新颖单词进行精确的单词搜索,并获得相似句子搜索的可用性。

例如“我喜欢使用Xapian!”被嵌入为“我喜欢使用UNK”。

如果仅使用关键字搜索“ Xapian”而不是语义相似性搜索,那么与使用GUSE和向量KNN相比,我将获得更多相关的结果。

关于如何提取GUSE已知/使用的词汇的任何想法?

1 个答案:

答案 0 :(得分:0)

我假设您已经安装了tensorflow和tensorflow_hub,并且您已经下载了模型。

重要:我假设您正在使用https://tfhub.dev/google/universal-sentence-encoder/4!无法保证不同版本的对象图看起来都一样,很可能需要修改。

找到它在磁盘上的位置-除非您设置/tmp/tfhub_modules环境变量(Windows / Mac具有不同的位置),否则它在TFHUB_CACHE_DIR的某个位置。该路径应包含一个名为saved_model.pb的文件,该文件是使用协议缓冲区序列化的模型。

不幸的是,该字典是在模型的协议缓冲区文件中序列化的,而不是作为外部资产进行序列化的,因此我们必须加载模型并从中获取变量。

策略是使用tensorflow的代码对文件进行反序列化,然后沿着序列化的对象树一直向下到达字典。

import importlib

MODEL_PATH = 'path/to/model/dir' # e.g. '/tmp/tfhub_modules/063d866c06683311b44b4992fd46003be952409c/'

# Use the tensorflow internal Protobuf loader. A regular import statement will fail.
loader_impl = importlib.import_module('tensorflow.python.saved_model.loader_impl')

saved_model = loader_impl.parse_saved_model(MODEL_PATH)

# reach into the object graph to get the tensor
graph = saved_model.meta_graphs[0].graph_def
function = graph.library.function
node_type, node_value = function[5].node_def
# if you print(node_type) you'll see it's called "text_preprocessor/hash_table"
# as well as get insight into this branch of the object graph we're looking at
words_tensor = node_value.attr.get("value").tensor

word_list = [i.decode('utf-8') for i in words_tensor.string_val]
print(len(word_list)) # -> 400004

一些有用的资源:

  1. GitHub issue有关的词汇
  2. 与此问题相关联的Tensorflow Google-group thread

额外说明

尽管GitHub问题可能会引起您的思考,但这里的40万个单词不是GloVe 400k词汇。您可以通过下载GloVe 6B embeddings (file link),提取glove.6B.50d.txt,然后使用以下代码比较两个字典来进行验证:

with open('/path/to/glove.6B.50d.txt') as f:
    glove_vocabulary = set(line.strip().split(maxsplit=1)[0] for line in f)

USE_vocabulary = set(word_list) # from above

print(len(USE_vocabulary - glove_vocabulary)) # -> 281150

检查不同的词汇本身很有趣,例如GloVe为什么要输入“ 287.9”?