我正在尝试将使用Keras定义的网络转换为tflite。网络如下:
model = tf.keras.Sequential([
# Embedding
tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]),
# GRU unit
tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
# Fully connected layer
tf.keras.layers.Dense(vocab_size)
])
但是,当我尝试导出到.tflite时,由于GRU层的存在,似乎出了点问题。
# Save trained model in .h5 format
keras_file = 'inference.h5'
tf.keras.models.save_model(model, keras_file)
# Load .h5 model with custom loss function
model = load_model('inference.h5', custom_objects={'loss': loss})
# Converting a tf.Keras model to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
有错误:
ValueError: Input 0 of node sequential_8/gru_2/AssignVariableOp was passed float from sequential_8/gru_2/68029:0 incompatible with expected resource.
有解决此问题的解决方案吗?
答案 0 :(得分:0)
状态暂不支持,您可以尝试设置stateful=False
吗?