Tensorflow-从冻结的.pb到.tflite的转换

时间:2019-09-13 17:01:40

标签: tensorflow

我正在尝试将自定义的冻结.pb文件(saved_model.pb和saved_model.pbtxt)转换为.tflite格式,以便将其加载到Coral开发板上。尝试使用docs中的以下python代码进行转换(导出量化的GraphDef)。使用的模型是MobileNetV2。

文件名:convert.py

import tensorflow as tf

img = tf.placeholder(name="normalized_input_image_tensor", dtype=tf.float32, shape=(1, 300, 300, 3))
const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
val = img + const
out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="TFLite_Detection_PostProcess")

with tf.Session() as sess:
  converter = tf.lite.TFLiteConverter.from_session(sess, [img], [out])
  converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
  input_arrays = converter.get_input_arrays()
  converter.quantized_input_stats = {input_arrays[0] : (0., 1.)}  # mean, std_dev
  tflite_model = converter.convert()
  open("converted_model.tflite", "wb").write(tflite_model)

对tensorflow还是一个新手,我认为我的输入和输出如下

输入 Input

输出 Output

运行上面的python代码后,输出的转换后的.tflite模型大小仅为1kb,看起来过于简单。 tflite output

这确实有几个问题。

  1. python代码中的输入是否正确?特别是,name =“ normalized_input_image_tensor”
  2. python代码中的输出是否正确?特别是,name =“ TFLite_Detection_PostProcess”
  3. .tflite输出应采用什么形状-应该更复杂吗?

谢谢!当我偶然发现时,感谢您对新手的耐心。

可视化--https://github.com/lutzroeder/Netron

编辑,发现这给了我输入名称

>>> import tensorflow as tf
>>> gf = tf.GraphDef()
>>> gf.ParseFromString(open('saved_model.pb','rb').read())
>>> [n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Placeholder')]

输出:['normalized_input_impage_tesnor =>占位符']

将代码更改为

img = tf.placeholder(name="Placeholder", dtype=tf.float32, shape=(1, 300, 300, 3))
const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
val = img + const
out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="MobilenetV2/Predictions/Softmax")

结果相同

1 个答案:

答案 0 :(得分:0)

tflite_convert --output_file=converted_model.tflite --graph_def_file=saved_model.pb --input_arrays=input --output_arrays=TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3  --input_shape=1,300,300,3 --allow_custom_ops

现在它可以工作了。