如何从tensorflow.js下载模型和权重

时间:2019-01-06 21:16:49

标签: javascript tensorflow machine-learning tensorflow.js

我正在尝试下载一个经过训练的包括权重的tensorflow.js模型,以便以标准方式在tensorflow中以离线方式在python中使用,这是一个不在任何早期阶段的项目的一部分,因此请切换到tensorflow .js是不可能的。 但是我无法弄清楚如何下载这些模型,以及是否有必要进行一些模型转换。

我知道在javascript中我可以访问模型并通过像这样调用它们来使用它们 但是,如果是这样的话,实际上如何冻结.ckpt文件或模型?

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.3"></script>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/posenet@0.2.3"></script>

我的最终目标是获取冻结的模型文件,并像在普通版本的tensorflow中那样获取输出。 另外,它会在离线环境中使用,因此任何在线参考都不会有用。

感谢您的答复

2 个答案:

答案 0 :(得分:2)

可以通过调用模型的方法save保存模型拓扑及其权重。

const model = tf.sequential();
model.add(tf.layers.dense(
     {units: 1, inputShape: [10], activation: 'sigmoid'}));
const saveResult = await model.save('downloads://mymodel'));
// This will trigger downloading of two files:
//   'mymodel.json' and 'mymodel.weights.bin'.
console.log(saveResult);

根据保存模型的位置及其权重,可以使用不同的方案字符串(localStorage,IndexDB等)。 doc

答案 1 :(得分:0)

我去了https://storage.googleapis.com/tfjs-models/,找到了列出所有文件的目录。我找到了相关的文件(我希望所有mobilenet浮动,而不是量化的mobileNet),然后填充此file_uris列表。

base_uri = "https://storage.googleapis.com/tfjs-models/"
file_uris = [
    "savedmodel/posenet/mobilenet/float/050/group1-shard1of1.bin",
    "savedmodel/posenet/mobilenet/float/050/model-stride16.json",
    "savedmodel/posenet/mobilenet/float/050/model-stride8.json",
    "savedmodel/posenet/mobilenet/float/075/group1-shard1of2.bin",
    "savedmodel/posenet/mobilenet/float/075/group1-shard2of2.bin",
    "savedmodel/posenet/mobilenet/float/075/model-stride16.json",
    "savedmodel/posenet/mobilenet/float/075/model-stride8.json",
    "savedmodel/posenet/mobilenet/float/100/group1-shard1of4.bin",
    "savedmodel/posenet/mobilenet/float/100/group1-shard2of4.bin",
    "savedmodel/posenet/mobilenet/float/100/group1-shard3of4.bin",
    "savedmodel/posenet/mobilenet/float/100/model-stride16.json",
    "savedmodel/posenet/mobilenet/float/100/model-stride8.json"
]

然后我使用python将文件迭代下载到相同的文件夹中。

from urllib.request import urlretrieve
import requests
from pathlib import Path

for file_uri in file_uris:
    uri = base_uri + file_uri
    save_path = "/".join(file_uri.split("/")[:-1])
    Path(save_path).mkdir(parents=True, exist_ok=True)
    urlretrieve(uri, file_uri)
    print(path, file_uri)

尝试此代码时,我很喜欢 Jupyter Lab (Jupyter Notebook也很不错)。

有了这个,您将获得一个包含bin文件(权重)和json文件(图形模型)的文件夹。不幸的是,这些是图形模型,因此它们无法转换为SavedModels,因此它们对您绝对没有用。让我知道是否有人找到了在常规TensorFlow(最好是2.0+)中运行这些tfjs图形模型文件的方法。


您还可以从TFHub下载具有“整个”模型的zip文件,例如,here提供了2字节量化的ResNet PoseNet。