TensorFlow Lite:toco_convert用于任意大小的输入张量

时间:2018-04-25 18:14:48

标签: python tensorflow

考虑将我的TensorFlow模型转换为Flatbuf格式(.tflite)。

但是,我的模型允许输入任意大小,即您可以一次分类一个项目或N个项目。当我尝试转换时,它会抛出一个错误,因为我的一个输入/输出设备属于NoneType类型。

想想TensorFlow MNIST tutorial之类的东西,在计算图中,我们的输入x的形状为[None, 784]

tflite dev guide,您可以将模型转换为FlatBuf,如下所示:

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")

with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
  open("converteds_model.tflite", "wb").write(tflite_model)

但是,这对我不起作用。 MWE可以是:

import tensorflow as tf

img = tf.placeholder(name="inputs", dtype=tf.float32, shape=(None, 784))
out = tf.identity(inputs, name="out")


with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
  open("converteds_model.tflite", "wb").write(tflite_model)

错误:TypeError: __int__ returned non-int (type NoneType)

查看 tf.contrib.lite.toco_convert 文档,我们有" input_tensors:输入张量列表。使用foo.get_shape()和foo.dtype计算类型和形状。"。这就是我们失败的可能性。但我不确定是否有我应该使用的论点或允许我导出这样的模型的东西

1 个答案:

答案 0 :(得分:1)

此问题已在最新的转换器代码中解决。您可以传递第一个维度为None(第一个维度通常为批处理)的输入张量,转换器将正确处理它。

顺便说一句,在调用解释器之前,您可以调用interpreter.resize_input_tensor来更改批处理大小。