如何从本地系统加载TF集线器模型

时间:2020-03-07 14:54:02

标签: python tensorflow

一种方法是每次从tensorflow_hub下载模型,如下所示

import tensorflow as tf
import tensorflow_hub as hub

hub_url = "https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1"
embed = hub.KerasLayer(hub_url)
embeddings = embed(["A long sentence.", "single-word", "http://example.com"])
print(embeddings.shape, embeddings.dtype)

我想一次下载该文件,然后一次又一次地使用,而不必每次都下载

4 个答案:

答案 0 :(得分:1)

您可以使用hub.load()方法来加载TF Hub模块。另外,docs说,

当前,仅TensorFlow 2.x和 通过调用tensorflow.saved_model.save()创建的模块。的 该方法适用于渴望模式和图形模式。

hub.load方法具有参数handle。模块句柄的类型是

  1. 智能URL解析器,例如tfhub.dev,例如:https://tfhub.dev/google/nnlm-en-dim128/1

  2. Tensorflow支持的文件系统上的目录,其中包含模块文件。其中可能包括本地目录(例如/usr/local/mymodule)或Google Cloud Storage存储桶(gs://mymodule)。

  3. 指向模块的TGZ存档的URL,例如https://example.com/mymodule.tar.gz

您可以使用第二点和第三点。

答案 1 :(得分:1)

也许其他人可能会从具体的、可重复的答案中受益。这篇文章对应于这个specific tfhub model

tensorflow_hub 版本:0.12.0
tensorflow 版本:2.2.0

我在我的 Linux 服务器上设置了以下路径:

tf2

(出于各种原因,我们仍然对 Tensorflow 1.x 有一些需求,所以我认为根据模型是否设计用于 tensorflow 1.x 与 tensorflow 2.x 来分离模型可能是个好主意,因此我路径中的 # bash tar xzf bert_en_uncased_L-12_H-768_A-12_4.tar.gz )

然后我下载了模型文件,把它推送到我的Linux服务器,放在上面的位置,然后执行:

# python
import os
os.listdir("/opt/tfhub/tf2/bert_en_uncased_L-12_H-768_A-12_4/")
>>> ['keras_metadata.pb', 'saved_model.pb', 'assets', 'variables']

这给了我以下文件:

# python
import tensorflow_hub as tfhub
import tensorflow as tf
bert_layer = tfhub.KerasLayer(tfhub.load("/opt/tfhub/tf2/bert_en_uncased_L-12_H-768_A-12_4"))

然后我可以像这样加载模型:

require 'google/apis/oauth2_v2'
service = Google::Apis::Oauth2V2::Oauth2Service.new
service.tokeninfo(access_token: "Your Access Token").email

答案 2 :(得分:1)

如果有人想知道模型在 Windows 上的默认保存位置,就像我一样,它在这里。

<块引用>

C:\Users\AvrakDavra\AppData\Local\Temp\tfhub_modules\

显然,您可以在任何地方下载并提及该路径和 tfhub 将从那里获取,但以防万一。 立即打开 Windows 上的临时文件夹。

  1. 按 Windows 按钮+R
  2. 写入 %TEMP%

它会为您的用户名打开临时文件夹,默认情况下是 tfhub_modules 文件夹。它将包含如下文件夹

enter image description here

文本文件的内容与下面类似。

Module: https://tfhub.dev/google/universal-sentence-encoder/4 Download Time: 2021-07-17 18:17:09.714147 Downloader Hostname: LAPTOP(PID:12720)

答案 3 :(得分:0)

  1. 从url +“?tf-hub-format = compressed”下载模型
    例如“ https://tfhub.dev/google/tf2-preview/nnlm-zh-dim128/1?tf-hub-format=compressed”
  2. 解压缩
  3. 在代码
    中加载未压缩的文件夹
import tensorflow as tf
import tensorflow_hub as hub

embed = hub.KerasLayer('path/to/untarred/folder')
embeddings = embed(["A long sentence.", "single-word", "http://example.com"])
print(embeddings.shape, embeddings.dtype)