我已经使用伪量化(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)在Cifar-10上使用张量流训练了一个简单的CNN模型。然后,我使用toco生成了一个.tflite文件。现在,我想使用python解释器测试tflite模型。
由于我在训练期间使用了tf.image.per_image_standardization减去均值并除以方差。我需要对测试数据做同样的事情吗?但是,问题是,我的模型已经被tflite完全量化,并且仅将uint8数据作为输入。为了进行图像标准化,我需要将图像转换为float32。那么,在这种情况下,如何将其转换回uint8,或者甚至需要对图像进行标准化?谢谢。
答案 0 :(得分:2)
因此,事实证明,我需要对测试数据进行标准化才能获得良好的准确性。 为此,我直接将uint8输入图像输入tf.image.per_image_standardization函数。该函数会将uint8数据转换为float32,然后进行标准化(减去均值,除以std)。您可以在此处找到该函数的源代码:https://github.com/tensorflow/tensorflow/blob/r1.11/tensorflow/python/ops/image_ops_impl.py
现在,我有了标准化的float32输入图像。我所做的是编写一个量化函数以将float32图像量化回uint8。数学来自本文:https://arxiv.org/abs/1803.08607
现在,我有标准化的uint8 输入图像,然后使用tflite解释器python API测试模型。它按预期工作。