如何从Tensorflow.js(.json)模型转换为Tensorflow(SavedModel)或Tensorflow Lite(.tflite)模型?

时间:2020-06-23 22:38:55

标签: tensorflow tensorflow-lite tensorflow.js tensorflow2

我有一个downloaded来自Google的Tensorflow.js(tfjs)的经过预先训练的PoseNet模型,因此它是一个 json 文件。

但是,我想在Android上使用它,因此我需要.tflite模型。尽管有人将类似的模型从tfjs“移植”到tflite here,但我不知道他们转换了哪种模型(PoseNet有很多变体)。我想自己做。另外,我不想运行有人上传到stackOverflow中的文件中的任意代码:

警告:小心不可信任的代码-TensorFlow模型是代码。有关详细信息,请参见安全使用TensorFlow。 Tensorflow docs

有人知道方便的方法吗?

1 个答案:

答案 0 :(得分:2)

您可以通过查看json文件来找出tfjs格式。通常会说“图形模型”。它们之间的区别是here

从tfjs图形模型到SavedModel(更常见)

使用tfjs-to-tfPatrick Levin

import tfjs_graph_converter.api as tfjs
tfjs.graph_model_to_saved_model(
               "savedmodel/posenet/mobilenet/float/050/model-stride16.json",
               "realsavedmodel"
            )

# Code below taken from https://www.tensorflow.org/lite/convert/python_api
converter = tf.lite.TFLiteConverter.from_saved_model("realsavedmodel")
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('model.tflite', 'wb') as f:
  f.write(tflite_model)

从tfjs图层模型到SavedModel

注意:这仅适用于图层模型格式,不适用于问题中的图形模型格式。我已经写了它们之间的区别here


  1. Install并使用tensorflowjs-convert将.json文件转换为Keras HDF5文件(来自另一个SO thread)。

在Mac上,您会遇到运行pyenv(fix)的问题,而在Z-shell上,pyenv无法正确加载(fix)。另外,一旦pyenv运行,请使用python -m pip install tensorflowjs而不是pip install tensorflowjs,因为pyenv并没有为我改变pip使用的python。

一旦您遵循了tensorflowjs_converter guide,请运行tensorflowjs_converter来验证它是否正确运行,并且应该警告您Missing input_path argument。然后:

tensorflowjs_converter --input_format=tfjs_layers_model --output_format=keras tfjs_model.json hdf5_keras_model.hdf5
  1. 使用TFLiteConverter将Keras HDF5文件转换为SavedModel(标准Tensorflow模型文件)或直接转换为.tflite文件。以下内容在Python文件中运行:
# Convert the model.
model = tf.keras.models.load_model('hdf5_keras_model.hdf5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert() 
    
# Save the TF Lite model.
with tf.io.gfile.GFile('model.tflite', 'wb') as f:
f.write(tflite_model)

或保存到SavedModel:

# Convert the model.
model = tf.keras.models.load_model('hdf5_keras_model.hdf5')
tf.keras.models.save_model(
    model, filepath, overwrite=True, include_optimizer=True, save_format=None,
    signatures=None, options=None
)