TFLite fails to quantize a TensorFlow model's input and output to INT8

Date: 2021-03-09 17:22:08

Tags: tensorflow tensorflow-lite

I am having trouble converting a TensorFlow model to TensorFlow Lite. I want to convert the whole model with quantization, but after finishing this step and visualizing the model's architecture, I found that the input and output are still Float. Can you help me solve this problem?

Version info: tensorflow 2.3.1 / python 3.6

Validation data

validation_generator = validation_datagen.flow_from_directory(
    valid_data_dir,
    target_size=(img_width, img_height),
    classes=classes,
    batch_size=32,
    class_mode='categorical',
    )

Model architecture

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, AveragePooling2D, Dropout, Flatten, Dense

model = Sequential()

model.add(Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(128, 128, 3)))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.3))

model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.4))

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))

model.summary()

After training / converting the model to Tflite

def representative_dataset_gen():
    for i in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            data = tf.reshape(data_xx, shape=[-1, 128, 128, 3])
            yield [data]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen

converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

converter.inference_input_tpye  = tf.int8

converter.inference_output_tpye = tf.int8

quantiz_model = converter.convert()

open("/content/drive/My Drive/model.tflite", "wb").write(quantiz_model)

Model properties (screenshot: input and output still shown as float)

1 Answer:

Answer 0 (score: 0)

From the comments:

"Looks like you made a typo; it should be converter.inference_input_type = tf.int8 and converter.inference_output_type = tf.int8" (paraphrased from daverim)

Working code for converting the model to Tflite

def representative_dataset_gen():
    for i in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            data = tf.reshape(data_xx, shape=[-1, 128, 128, 3])
            yield [data]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen

converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

converter.inference_input_type  = tf.int8

converter.inference_output_type = tf.int8

quantiz_model = converter.convert()

open("/content/drive/My Drive/model.tflite", "wb").write(quantiz_model)
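To confirm the quantization actually took effect without opening a model viewer, the converted bytes can be loaded into tf.lite.Interpreter and the tensor dtypes inspected directly. Below is a minimal self-contained sketch of that check; the tiny stand-in model and its shapes are arbitrary assumptions, not the original network:

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model (assumption: any small Conv/Dense model illustrates the check).
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(8, 8, 3)),
    tf.keras.layers.Conv2D(4, 3, activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2, activation='softmax'),
])

def representative_dataset_gen():
    # Random calibration samples, matching the model's input shape and float32 dtype.
    for _ in range(10):
        yield [np.random.rand(1, 8, 8, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8    # correctly spelled attributes
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# Load the converted bytes in-memory and inspect the I/O tensor dtypes.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_input_details()[0]['dtype'])
print(interpreter.get_output_details()[0]['dtype'])
```

With the attributes spelled correctly both dtypes come back as int8; with the `tpye` misspelling the assignments silently create unused attributes on the converter, so the model keeps float32 input and output tensors, which matches the behavior described in the question.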