我一般对对象检测API和TensorFlow还是陌生的。我遵循了this教程,最后制作了frozen_inference_graph.pb
。我想在手机上运行该对象检测模型,据我所知,我需要将其转换为.tflite(如果没有任何意义,请使用lmk)。
当我尝试在此处使用此标准代码进行转换时:
import tensorflow as tf
graph = 'pathtomygraph'
input_arrays = ['image_tensor']
output_arrays = ['all_class_predictions_with_background']
converter = tf.lite.TFLiteConverter.from_frozen_graph(graph, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
它抛出一个错误,说:
ValueError:仅在第一个维度中不支持。张量 'image_tensor'具有无效的形状'[None,None,None,3]'
这是我在互联网上发现的常见错误,在搜索了许多线程之后,我尝试为代码提供额外的参数:
converter = tf.lite.TFLiteConverter.from_frozen_graph(
graph, input_arrays, output_arrays,input_shapes={"image_tensor":[1,600,600,3]})
现在看起来像这样:
import tensorflow as tf
graph = 'pathtomygraph'
input_arrays = ['image_tensor']
output_arrays = ['all_class_predictions_with_background']
converter = tf.lite.TFLiteConverter.from_frozen_graph(
graph, input_arrays, output_arrays,input_shapes={"image_tensor":[1,600,600,3]})
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
此可以工作,但最后会抛出另一个错误,说:
检查失败:array.data_type == array.final_data_type数组 “ image_tensor”的实际和最终数据类型不匹配 (data_type = uint8,final_data_type = float)。致命错误:已中止
我知道我的输入张量具有uint8的数据类型,这可能导致不匹配。我的问题是,这是处理事物的正确方法吗? (我想在手机上运行我的模型)。如果是,该如何解决错误? :/
非常感谢您。
答案 0 :(得分:0)
将模型输入(image_tensor
占位符)更改为数据类型tf.float32
。