我在将 TensorFlow 模型转换为 TensorFlow Lite 时遇到问题。我想对整个模型进行整数量化转换，但完成转换并可视化模型架构后，我发现输入和输出仍然是 float 类型。你能帮我解决这个问题吗？
版本信息：TensorFlow 2.3.1 / Python 3.6
# Batches of validation images (and their labels) loaded from valid_data_dir,
# restricted to the listed `classes`, 32 images per batch.
# Keras' flow_from_directory expects target_size as (height, width), so the
# height comes first here; the previous (img_width, img_height) order only
# worked because/if the images are square.
validation_generator = validation_datagen.flow_from_directory(
    valid_data_dir,
    target_size=(img_height, img_width),
    classes=classes,
    batch_size=32,
    class_mode='categorical',
)
# Small CNN classifier: three Conv2D -> AveragePooling2D -> Dropout stages,
# then a dense head ending in a 4-way softmax.
# NOTE(review): input_shape is hard-coded to (128, 128, 3); confirm it matches
# the size the data generator produces. The 4 output units presumably
# correspond to len(classes) - confirm.
model = Sequential()

# Stage 1: 32 filters; 'same' padding keeps spatial size, pooling halves it.
model.add(Conv2D(32, (3, 3), padding='same', activation='relu',
                 input_shape=(128, 128, 3)))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

# Stage 2: 64 filters.
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.3))

# Stage 3: 128 filters. Uses the same bare `Dropout` as the other stages so
# every layer comes from one namespace (no separate `layers` import needed).
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.4))

# Classifier head.
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))

model.summary()
def representative_dataset_gen():
    """Yield calibration samples for post-training integer quantization.

    Pulls 20 batches from ``validation_generator`` and yields every image on
    its own, as a one-element list holding a float32 tensor reshaped to
    ``[-1, 128, 128, 3]`` (a batch of one, assuming each image is already
    128x128x3 - TODO confirm against the generator's target size). This is
    the callable assigned to ``converter.representative_dataset``.
    """
    for _ in range(20):
        data_x, _labels = validation_generator.next()
        for data_xx in data_x:
            # Reshape the current image `data_xx`; reshaping `data` here
            # would read the local before assignment (UnboundLocalError).
            data = tf.reshape(data_xx, shape=[-1, 128, 128, 3])
            # Calibration inputs must match the float model's input dtype.
            yield [tf.cast(data, tf.float32)]
# Convert the Keras model to TFLite with integer quantization (code as posted
# in the question).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable quantization through the DEFAULT optimization set.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Calibration data used to choose activation quantization ranges.
converter.representative_dataset = representative_dataset_gen
# Only allow int8 builtin ops.
converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# NOTE(review): "tpye" is a typo for "type". Assigning a misspelled attribute
# raises no error, so these two lines have no effect on conversion, and the
# converted model keeps float input/output tensors. The correct names are
# inference_input_type / inference_output_type.
converter.inference_input_tpye = tf.int8
converter.inference_output_tpye = tf.int8
quantiz_model = converter.convert()
# NOTE(review): the file object is never explicitly closed; prefer
# `with open(...) as f: f.write(quantiz_model)`.
open("/content/drive/My Drive/model.tflite", "wb").write(quantiz_model)
答案 0（得分：0）
来自评论：
> 看起来你把属性名拼错了（`tpye` 应为 `type`），应该是：
> `converter.inference_input_type = tf.int8`
> `converter.inference_output_type = tf.int8`
（转述自 daverim 的评论）
将模型转换为 TFLite 的可正常运行的代码：
def representative_dataset_gen():
    """Yield calibration samples for post-training integer quantization.

    Pulls 20 batches from ``validation_generator`` and yields every image on
    its own, as a one-element list holding a float32 tensor reshaped to
    ``[-1, 128, 128, 3]`` (a batch of one, assuming each image is already
    128x128x3 - TODO confirm against the generator's target size). This is
    the callable assigned to ``converter.representative_dataset``.
    """
    for _ in range(20):
        data_x, _labels = validation_generator.next()
        for data_xx in data_x:
            # Reshape the current image `data_xx`; reshaping `data` here
            # would read the local before assignment (UnboundLocalError).
            data = tf.reshape(data_xx, shape=[-1, 128, 128, 3])
            # Calibration inputs must match the float model's input dtype.
            yield [tf.cast(data, tf.float32)]
# Convert the Keras model to an integer-quantized TFLite flatbuffer whose
# input and output tensors are int8.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable quantization through the DEFAULT optimization set.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Calibration data used to choose activation quantization ranges.
converter.representative_dataset = representative_dataset_gen
# Only allow int8 builtin ops (integer-only model).
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Correctly spelled attribute names: these are what switch the model's
# input/output tensors from float to int8.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
quantiz_model = converter.convert()
# Context manager guarantees the file is flushed and closed after writing.
with open("/content/drive/My Drive/model.tflite", "wb") as tflite_file:
    tflite_file.write(quantiz_model)