Converting a BERT model to TFLite

Date: 2020-04-01 09:37:39

Tags: python tensorflow tensorflow-lite tf-lite bert-language-model

I have code for a semantic search engine built with a pre-trained BERT model, and I want to convert the model to TFLite so I can deploy it to Google ML Kit. Is such a conversion even possible? It seems so, since it is covered on the official TensorFlow site (https://www.tensorflow.org/lite/convert), but I don't know where to start.

Code:


import scipy.spatial  # needed for the cosine-distance search below
from sentence_transformers import SentenceTransformer

# Load the BERT model. Various models trained on Natural Language Inference (NLI) https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/nli-models.md and 
# Semantic Textual Similarity are available https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/sts-models.md

model = SentenceTransformer('bert-base-nli-mean-tokens')

# A corpus is a list with documents split by sentences.

sentences = ['Absence of sanity', 
             'Lack of saneness',
             'A man is eating food.',
             'A man is eating a piece of bread.',
             'The girl is carrying a baby.',
             'A man is riding a horse.',
             'A woman is playing violin.',
             'Two men pushed carts through the woods.',
             'A man is riding a white horse on an enclosed ground.',
             'A monkey is playing drums.',
             'A cheetah is running behind its prey.']

# Each sentence is encoded as a 1-D vector with 768 dimensions
sentence_embeddings = model.encode(sentences)

print('Sample BERT embedding vector - length', len(sentence_embeddings[0]))

print('Sample BERT embedding vector - note includes negative values', sentence_embeddings[0])

#@title Semantic Search Form

# code adapted from https://github.com/UKPLab/sentence-transformers/blob/master/examples/application_semantic_search.py

query = 'Nobody has sane thoughts' #@param {type: 'string'}

queries = [query]
query_embeddings = model.encode(queries)

# Find the closest 3 sentences of the corpus for each query sentence based on cosine similarity
number_top_matches = 3 #@param {type: "number"}

print("Semantic Search Results")

for query, query_embedding in zip(queries, query_embeddings):
    distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine")[0]

    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])

    print("\n\n======================\n\n")
    print("Query:", query)
    print("\nTop", number_top_matches, "most similar sentences in corpus:")

    for idx, distance in results[0:number_top_matches]:
        print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))

3 Answers:

Answer 0 (score: 0)

First, you need the model in TensorFlow; the package you are using is written in PyTorch. Hugging Face's Transformers provides TensorFlow models you can start from. They also publish TFLite-ready models for Android.

In general, you start with a TensorFlow model and save it in the SavedModel format:

tf.saved_model.save(pretrained_model, "/tmp/pretrained-bert/1/")

You can then run the converter on it.

Answer 1 (score: 0)

Have you tried running the conversion tool (tflite_convert)? Does it complain about anything?

By the way, you may want to look at the TFLite team's question-answering example, which uses a BERT model: https://github.com/tensorflow/examples/tree/master/lite/examples/bert_qa/android

Answer 2 (score: 0)

I could not find anything about using a BERT model on a mobile device to compute document embeddings and run a k-nearest-documents search, as your example does. It is also not a great idea: BERT models are expensive to execute and have a huge number of parameters, so the model file is large as well (400 MB+).

However, you can now use BERT and MobileBERT for text classification and question answering on a phone. As Xunkai mentioned, perhaps you can start from the demo app that interfaces with the MobileBERT tflite model. I believe your use case will be better supported in the near future.