使用TensorFlow v2.2将Keras .h5模型转换为TFLite .tflite

时间:2020-05-23 16:17:49

标签: tensorflow keras neural-network tf-lite

我正在尝试将使用Keras定义的网络转换为tflite。网络如下:

model = tf.keras.Sequential([
        # Embedding
        tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]),
        # GRU unit
        tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        # Fully connected layer
        tf.keras.layers.Dense(vocab_size)
        ])

但是,当我尝试导出到.tflite时,由于GRU层的存在,似乎出了点问题。

# Save trained model in .h5 format
keras_file = 'inference.h5'
tf.keras.models.save_model(model, keras_file)

# Load .h5 model with custom loss function
model = load_model('inference.h5', custom_objects={'loss': loss})

# Converting a tf.Keras model to a TensorFlow Lite model.
converter    = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

有错误:

ValueError: Input 0 of node sequential_8/gru_2/AssignVariableOp was passed float from sequential_8/gru_2/68029:0 incompatible with expected resource.

有解决此问题的解决方案吗?

1 个答案:

答案 0 :(得分:0)

状态暂不支持,您可以尝试设置stateful=False吗?