我想将模型转换为tflite格式。但是,我不断收到错误消息,不支持运算符BroadcastTo。我能够解决此错误的唯一方法是通过将模型定义为具体函数。我怎么只训练一个具体的功能,甚至有可能吗?
(不是我的实际模型,只是错误的最小示例)
# -------------------- Doesn't Work -------------------- class CustomLayer(tf.keras.layers.Layer): def __init__(self, num_outputs): super(CustomLayer, self).__init__() def call(self, input): trans = tf.ones([1, 25, 37, 12]) trans = tf.math.add(trans, input) m1s = tf.ones([1, 25, 37, 12, 5, 5]) reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1]) f = tf.multiply(reshape, m1s) return f input = tf.keras.Input(shape=(1), dtype=tf.float32) f = CustomLayer(1)(input) model = tf.keras.Model(inputs=input, outputs=f) converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert() open("model.tflite", "wb").write(tflite_model)
# -------------------- Concrete Function (Works) -------------------- root = tf.Module() root.var = None @tf.function def example(number): trans = tf.ones([1, 25, 37, 12]) trans = tf.add(trans, number) m1s = tf.ones([1, 25, 37, 12, 5, 5]) reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1]) f = tf.multiply(reshape, m1s) return f root.func = example concrete_func = root.func.get_concrete_function(3) converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) tflite_model = converter.convert() open("model.tflite", "wb").write(tflite_model)
请注意,我已经尝试了以下解决方案:
tf.lite.TFLiteConverter.from_keras_model
tf.lite.TFLiteConverter.from_saved_model
concrete_func = model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
将Keras模型保存为SavedModel并从中获取具体功能我知道也可以做一个自定义运算符,但这需要对tensorflow的C ++ API有深入的了解,知道BroadcastTo在内部如何工作,知道将文件放在哪里,编译自定义AAR,以及建立自定义JNI层。
答案 0 :(得分:0)
试试这个代码!!
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
model_path='/content/model.h5'
model=keras.models.load_model(model_path)
reloaded = keras.models.load_model(model_path,custom_objects{'KerasLayer':hub.KerasLayer})
TFLITE_MODEL = f"path/model.tflite"
# Get the concrete function from the Keras model.
run_model = tf.function(lambda x : reloaded(x))
# Save the concrete function.
concrete_func = run_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
)
# Convert the model to standard TensorFlow Lite model
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converted_tflite_model = converter.convert()
open(TFLITE_MODEL, "wb").write(converted_tflite_model)