我从Tensorflow 2.0教程中训练了Pix2Pix生成器,并通过tflite将其导出:
converter = tf.lite.TFLiteConverter.from_keras_model(generator)
tflite_model = converter.convert()
open("facades.tflite", "wb").write(tflite_model)
不幸的是,当我尝试推断时,我似乎遇到了来自 tf.keras.layers.BatchNormalization 的问题。
首先,推理结果仅返回Nan值。这可以通过禁用fused实现来解决。
第二,BatchNormalization层的行为根据我们是处于训练还是预测而有所不同。教程explicitly states以 training = True 模式进行预测。我不知道如何使用tflite模型来做到这一点。
一种解决方案是用 InstanceNormalization 替换 BatchNormalization 层,该实例可在 tensorflow_addons 中找到。 到tflite的转换完全没有问题,但是推断仍然有问题。 当我在解释器上调用invoke时,由于返回SEGFAULT而崩溃。根据堆栈调用,它将来自InstanceNormalization层的SquaredDifference运算符。
有人能设法将此TensorFlow 2.0模型转换为tflite并正确推断吗?怎么样 ?谢谢。
PS:我更喜欢使用BatchNormalization的解决方案,因为它是Keras中的标准层,因此也可以与TensorFlow javascript一起使用。