全量化不会将int8数据更改为将模型输入层更改为int8

时间:2020-09-10 13:17:55

标签: python tensorflow keras quantization tf-lite

我正在将keras h5模型量化为uint8。为了获得完整的uint8量化,用户dtlam26this post中告诉我,代表性数据集应该已经在uint8中,否则输入层仍在float32中。

问题是,如果我输入uint8数据,则在调用converter.convert()时会收到以下错误消息

ValueError:无法设置张量:得到类型为INT8的张量,但预期 输入FLOAT32作为输入178,名称:input_1

似乎,该模型仍然期望float32。所以我用

检查了基本的keras_vggface预训练模型(from here
from keras_vggface.vggface import VGGFace
import keras

pretrained_model = VGGFace(model='resnet50', include_top=False, input_shape=(224, 224, 3), pooling='avg')  # pooling: None, avg or max

pretrained_model.save()

,生成的h5模型具有带float32的输入层。 接下来,我使用uint8作为输入dtype更改了模型定义:

def RESNET50(include_top=True, weights='vggface',
             ...)

    if input_tensor is None:
        img_input = Input(shape=input_shape, dtype='uint8')

但是对于int只允许使用int32。但是,使用int32会导致问题,以下各层都需要float32。

这似乎不是手动对所有图层进行操作的正确方法。

为什么在量化过程中我的模型不只包含uint8数据,而是自动将输入更改为uint8?

我想念什么?你知道解决方案吗?非常感谢。

1 个答案:

答案 0 :(得分:0)

dtlam26用户的解决方案

尽管该模型仍然不能与Google NNAPI一起运行,但感谢delan:使用int8和int8输出量化模型的解决方案是使用TF 1.15.3或TF2.2.0。

...
converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_model_dir + modelname) 
        
def representative_dataset_gen():
  for _ in range(10):
    pfad='pathtoimage/000001.jpg'
    img=cv2.imread(pfad)
    img = np.expand_dims(img,0).astype(np.float32) 
    # Get sample input data as a numpy array in a method of your choosing.
    yield [img]
    
converter.representative_dataset = representative_dataset_gen

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.experimental_new_converter = True

converter.target_spec.supported_types = [tf.int8]
converter.inference_input_type = tf.int8 
converter.inference_output_type = tf.int8 
quantized_tflite_model = converter.convert()
if tf.__version__.startswith('1.'):
    open("test153.tflite", "wb").write(quantized_tflite_model)
if tf.__version__.startswith('2.'):
    with open("test220.tflite", 'wb') as f:
        f.write(quantized_tflite_model)