将Pix2Pix生成器导出到tflite模型

时间:2020-05-14 17:47:56

标签: python tensorflow2.0 tensorflow-lite tensorflow.js

我从Tensorflow 2.0教程中训练了Pix2Pix生成器,并通过tflite将其导出:

converter = tf.lite.TFLiteConverter.from_keras_model(generator)
tflite_model = converter.convert()
open("facades.tflite", "wb").write(tflite_model)

不幸的是,当我尝试推断时,我似乎遇到了来自 tf.keras.layers.BatchNormalization 的问题。

首先,推理结果仅返回Nan值。这可以通过禁用fused实现来解决。

第二,BatchNormalization层的行为根据我们是处于训练还是预测而有所不同。教程explicitly states training = True 模式进行预测。我不知道如何使用tflite模型来做到这一点。

一种解决方案是用 InstanceNormalization 替换 BatchNormalization 层,该实例可在 tensorflow_addons 中找到。 到tflite的转换完全没有问题,但是推断仍然有问题。 当我在解释器上调用invoke时,由于返回SEGFAULT而崩溃。根据堆栈调用,它将来自InstanceNormalization层的SquaredDifference运算符。

有人能设法将此TensorFlow 2.0模型转换为tflite并正确推断吗?怎么样 ?谢谢。

PS:我更喜欢使用BatchNormalization的解决方案,因为它是Keras中的标准层,因此也可以与TensorFlow javascript一起使用。

0 个答案:

没有答案