可以训练具体的功能吗?

时间:2019-08-18 06:47:21

标签: tensorflow-lite tensorflow2.0

我想将模型转换为tflite格式。但是,我不断收到错误消息,不支持运算符BroadcastTo。我能够解决此错误的唯一方法是通过将模型定义为具体函数。我怎么只训练一个具体的功能,甚至有可能吗?

(不是我的实际模型,只是错误的最小示例)


    # -------------------- Doesn't Work --------------------

    class CustomLayer(tf.keras.layers.Layer):
      def __init__(self, num_outputs):
        super(CustomLayer, self).__init__()

      def call(self, input):
        trans = tf.ones([1, 25, 37, 12])
        trans = tf.math.add(trans, input)
        m1s = tf.ones([1, 25, 37, 12, 5, 5])
        reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
        f = tf.multiply(reshape, m1s)
        return f

    input = tf.keras.Input(shape=(1), dtype=tf.float32)
    f = CustomLayer(1)(input)
    model = tf.keras.Model(inputs=input, outputs=f)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    open("model.tflite", "wb").write(tflite_model)


    # -------------------- Concrete Function (Works) --------------------

    root = tf.Module()
    root.var = None

    @tf.function
    def example(number):
      trans = tf.ones([1, 25, 37, 12])
      trans = tf.add(trans, number)
      m1s = tf.ones([1, 25, 37, 12, 5, 5])
      reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
      f = tf.multiply(reshape, m1s)
      return f

    root.func = example
    concrete_func = root.func.get_concrete_function(3)
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()
    open("model.tflite", "wb").write(tflite_model)

请注意,我已经尝试了以下解决方案:

  1. Keras 中定义模型(以便轻松进行训练)并使用
    tf.lite.TFLiteConverter.from_keras_model
  2. 将Keras模型另存为 SavedModel 并使用
    tf.lite.TFLiteConverter.from_saved_model
  3. 使用
    concrete_func = model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    将Keras模型保存为SavedModel并从中获取具体功能

我知道也可以做一个自定义运算符,但这需要对tensorflow的C ++ API有深入的了解,知道BroadcastTo在内部如何工作,知道将文件放在哪里,编译自定义AAR,以及建立自定义JNI层。

1 个答案:

答案 0 :(得分:0)

试试这个代码!!

import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub

model_path='/content/model.h5'
model=keras.models.load_model(model_path)
reloaded = keras.models.load_model(model_path,custom_objects{'KerasLayer':hub.KerasLayer})

TFLITE_MODEL = f"path/model.tflite"


# Get the concrete function from the Keras model.
run_model = tf.function(lambda x : reloaded(x))

# Save the concrete function.
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)

# Convert the model to standard TensorFlow Lite model
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converted_tflite_model = converter.convert()
open(TFLITE_MODEL, "wb").write(converted_tflite_model)