如何为“ .tflite”转换过程识别Tensoflow模型的输入和输出数组?

时间:2019-08-12 06:49:52

标签: tensorflow tensorflow-lite

我正在尝试从.pb文件生成量化的.tflite模型。对于#process,我需要模型的“ input_arrays”和“ output_arrays”。

我尝试使用以下方法来识别输入数组和输出数组。但是他们都不起作用。

方法1:

import tensorflow as tf
frozen='/output/freeze/frozen_inference_graph.pb'
gf = tf.GraphDef()
gf.ParseFromString(open(frozen,'rb').read())
[n.name + '=>' +  n.op for n in gf.node if n.op in ('Softmax','Placeholder')]    
[n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Mul')]

方法2:

import tensorflow as tf
gf = tf.GraphDef()   
m_file = open('/output/freeze/frozen_inference_graph.pb','rb')
gf.ParseFromString(m_file.read())
for n in gf.node:
    print( n.name )

tflite转换查询:

import tensorflow as tf
graph_def_file = "new/barun/frozen_inference_graph.pb"
input_arrays = ['image_tensor']
output_arrays = ['BoxPredictor_5/ClassPredictor/act_quant/FakeQuantWithMinMaxVars']

converter = tf.lite.TFLiteConverter.from_frozen_graph(
  graph_def_file, input_arrays, output_arrays,input_shapes={"image_tensor":[1,300,300,3]})
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
tflite_model = converter.convert()
open("frozen_inference_graph_fd2819_2.tflite", "wb").write(tflite_model)

如何查找.pb文件的input_array和output_array?

2 个答案:

答案 0 :(得分:0)

冻结的模型是您自己的吗?在这种情况下,您可以在创建模型时命名图层。

否则,您也许可以使用某种检查工具(例如Netron)打开模型,然后查找名称。

答案 1 :(得分:0)

如果您自己创建了模型,那么input_arrays将是输入占位符张量名称的列表。我们用于推断的输出张量的名称将在output_arrays中。

如果我们可以从其他来源轻松下载模型,则可以采用一些解决方法。

  1. 在大多数TF项目中,作者在README部分中提供了输入/输出张量的详细信息。
  2. 此外,大多数项目都使用graph.get_tensor_by_name()方法来获取输入输出张量,以便他们可以将它们用于推理(尤其是tf.Session())。您可以深入探索推理文件以发现张量的名称。

如果所有其他方法都不起作用,那么正如@Silfverstrom所提到的,我们可以使用Netron来可视化图形。

可视化图形的另一种方法可以是TensorBoard。将图形写入事件文件,如

file_writer = tf.summary.FileWriter('/path/to/logs', sess.graph)

然后打开TensorBoard,

tensorboard --logdir path/to/logs

最后,只有图形可以帮助您解决问题。