我正在尝试将此CPM-TF模型转换为TFLite,但是要使用TocoConverter,我需要指定输入和输出张量。 https://github.com/timctho/convolutional-pose-machines-tensorflow
我运行了包含的run_freeze_model.py
,并得到了cpm_hand_frozen.pb
(GraphDef?)文件。
在这篇文章中,我复制了代码片段,用于使用已知的输入和输出转换ProtoBuf文件。但是通过查看模型定义代码,我很难找到输入和输出的正确答案。 Tensorflow Convert pb file to TFLITE using python
import tensorflow as tf
import numpy as np
from config import FLAGS
path_to_frozen_graphdef_pb = 'frozen_models/cpm_hand_frozen.pb'
def main(argv):
input_tensors = [1, FLAGS.input_size, FLAGS.input_size, 3]
output_tensors = np.zeros(FLAGS.num_of_joints)
frozen_graph_def = tf.GraphDef()
with open(path_to_frozen_graphdef_pb, 'rb') as f:
frozen_graph_def.ParseFromString(f.read())
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)
if __name__ == '__main__':
tf.app.run()
我是Tensorflow的新手,但我认为输入应定义为
[1, FLAGS.input_size, FLAGS.input_size, 3]
在这里找到:https://github.com/timctho/convolutional-pose-machines-tensorflow/blob/master/models/nets/cpm_hand.py#L23
不确定1代表什么,但None无效,我想其他参数是图像尺寸和颜色通道。
但是,使用该输入,它将返回错误:
AttributeError: 'int' object has no attribute 'dtype'
除了应该是数组之外,我对输出应该是什么一无所知。
浏览TF文档,看来我需要将输入定义为张量(显而易见!)。 https://www.tensorflow.org/lite/convert/python_api
input_tensors = tf.placeholder(name="img", dtype=tf.float32, shape=(1,FLAGS.input_size, FLAGS.input_size, 3))
这不会返回错误,但是我仍然需要弄清楚输入是否正确以及输出应该是什么样。
def tflite_converter():
graph_def_file = os.path.join('frozen_models', '{}_frozen.pb'.format('cpm_hand'))
input_arrays = ['input_placeholer']
output_arrays = [FLAGS.output_node_names]
converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open('{}.tflite'.format('cpm_hand'), 'wb').write(tflite_model)
我希望我做的正确。我将尝试在Android上对该模型进行推断。
我也确实认为输入张量input_placeholder
中存在拼写错误。它似乎可以在代码本身中得到纠正,但是通过从预先训练的模型中打印出所有节点名称,可以看到拼写input_placeholer
。
可以在此处看到节点名称:https://github.com/timctho/convolutional-pose-machines-tensorflow/issues/59
Ubuntu 18.04
CUDA 9.1和cuDNN 7.0
Python 3.6.5
Tensorflow GPU 1.6
推理的工作原理很吸引人,因此设置本身应该没有问题。