定义tf.lite.TocoConverter

时间:2019-04-04 12:04:40

标签: python tensorflow tensorflow-lite

我正在尝试将此CPM-TF模型转换为TFLite,但是要使用TocoConverter,我需要指定输入和输出张量。 https://github.com/timctho/convolutional-pose-machines-tensorflow

我运行了包含的run_freeze_model.py,并得到了cpm_hand_frozen.pb(GraphDef?)文件。

在这篇文章中,我复制了代码片段,用于使用已知的输入和输出转换ProtoBuf文件。但是通过查看模型定义代码,我很难找到输入和输出的正确答案。 Tensorflow Convert pb file to TFLITE using python

import tensorflow as tf
import numpy as np
from config import FLAGS

path_to_frozen_graphdef_pb = 'frozen_models/cpm_hand_frozen.pb'


def main(argv):
    input_tensors = [1, FLAGS.input_size, FLAGS.input_size, 3]
    output_tensors = np.zeros(FLAGS.num_of_joints)
    frozen_graph_def = tf.GraphDef()

    with open(path_to_frozen_graphdef_pb, 'rb') as f:
        frozen_graph_def.ParseFromString(f.read())

    tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)

if __name__ == '__main__':
    tf.app.run()

我是Tensorflow的新手,但我认为输入应定义为
[1, FLAGS.input_size, FLAGS.input_size, 3] 在这里找到:https://github.com/timctho/convolutional-pose-machines-tensorflow/blob/master/models/nets/cpm_hand.py#L23

不确定1代表什么,但None无效,我想其他参数是图像尺寸和颜色通道。

但是,使用该输入,它将返回错误: AttributeError: 'int' object has no attribute 'dtype'

除了应该是数组之外,我对输出应该是什么一无所知。


更新1

浏览TF文档,看来我需要将输入定义为张量(显而易见!)。 https://www.tensorflow.org/lite/convert/python_api

input_tensors = tf.placeholder(name="img", dtype=tf.float32, shape=(1,FLAGS.input_size, FLAGS.input_size, 3))

这不会返回错误,但是我仍然需要弄清楚输入是否正确以及输出应该是什么样。


更新2 好了,所以我终于有了这个代码片段来吐出tflite模型

def tflite_converter():
    graph_def_file = os.path.join('frozen_models', '{}_frozen.pb'.format('cpm_hand'))

    input_arrays = ['input_placeholer']
    output_arrays = [FLAGS.output_node_names]

    converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()

    open('{}.tflite'.format('cpm_hand'), 'wb').write(tflite_model)

我希望我做的正确。我将尝试在Android上对该模型进行推断。 我也确实认为输入张量input_placeholder中存在拼写错误。它似乎可以在代码本身中得到纠正,但是通过从预先训练的模型中打印出所有节点名称,可以看到拼写input_placeholer

可以在此处看到节点名称:https://github.com/timctho/convolutional-pose-machines-tensorflow/issues/59


我的设置是:

Ubuntu 18.04
CUDA 9.1和cuDNN 7.0
Python 3.6.5
Tensorflow GPU 1.6

推理的工作原理很吸引人,因此设置本身应该没有问题。

0 个答案:

没有答案