如何在本地使用tf-hub模型

时间:2020-07-01 10:27:25

标签: python tensorflow keras tensorflow-hub bert-toolkit

我一直在尝试使用tf-hub https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2中的BERT模型。

import tensorflow_hub as hub
bert_layer = hub.keras_layer('./bert_en_uncased_L-12_H-768_A-12_2', trainable=True)

但是问题在于它每次运行后都会下载数据。

所以我从tf-hub https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2下载了.tar文件

现在我正尝试使用此下载的tar文件(解压缩后)

我已遵循本教程https://medium.com/@xianbao.qian/how-to-run-tf-hub-locally-without-internet-connection-4506b850a915

但是效果不佳,此博客文章中没有提供进一步的信息或脚本

如果有人可以提供完整的脚本以在本地使用下载的模型(不使用互联网),或者可以改善上述博客文章(中)。

我也尝试过

untarredFilePath = './bert_en_uncased_L-12_H-768_A-12_2'
bert_lyr = hub.load(untarredFilePath)
print(bert_lyr)

输出

<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x7f05c46e6a10>

似乎无效。

或者还有其他方法可以这样做吗?

4 个答案:

答案 0 :(得分:1)

嗯,我无法重现您的问题。对我有用的东西:

script.sh

# download the model file using the 'wget' program
wget "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2?tf-hub-format=compressed"

# rename the downloaded file name to 'tar_file.tar.gz'
mv 2\?tf-hub-format\=compressed tar_file.tar.gz

# extract tar_file.tar.gz to the local directory 
tar -zxvf tar_file.tar.gz

# turn off internet

# run a test script
python3 test.py

# running the last command prints some tensorflow warnings, and then '<tensorflow_hub.keras_layer.KerasLayer object at 0x7fd702a7d8d0>'

test.py

import tensorflow_hub as hub
print(hub.KerasLayer('.'))

答案 1 :(得分:1)

tensorflow_hub库将下载的和未压缩的模型缓存在磁盘上,以避免重复上传。 tensorflow.org/hub/caching的文档已经扩展,可以讨论这种情况和其他情况。

答案 2 :(得分:0)

我使用这篇中等文章(https://medium.com/@xianbao.qian/how-to-run-tf-hub-locally-without-internet-connection-4506b850a915)作为参考编写了此脚本。我正在项目中创建一个缓存目录,并且tensorflow模型被本地缓存在该缓存目录中,并且能够在本地加载模型。希望对您有帮助。

import os
os.environ["TFHUB_CACHE_DIR"] = r'C:\Users\USERX\PycharmProjects\PROJECTX\tf_hub'

import tensorflow as tf
import tensorflow_hub as hub
import hashlib

handle = "https://tfhub.dev/google/universal-sentence-encoder/4"
hashlib.sha1(handle.encode("utf8")).hexdigest()


embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
def get_sentence_embeddings(paragraph_array):
    embeddings=embed(paragraph_array)
    return embeddings

答案 3 :(得分:0)

从tf-hub团队获取信息后,他们将提供此解决方案。 假设您已经从下载按钮的官方tf-hub模型页面下载了.tar.gz文件。 您已经提取了它。您有一个包含资产,变量和模型的文件夹。 您将其放在工作目录中。

在脚本中只需将路径添加到该文件夹​​:

import tensroflow-hub as hub

model_path ='./bert_en_uncased_L-12_H-768_A-12_2' # in my case
# one thing the path you have to provide is for folder which contain assets, variable and model
# not of the model.pb itself

lyr = hub.KerasLayer(model_path, trainable=True)

希望它也对您有用。试试看