Tensorflow Lite:布尔标量作为输入张量

时间:2018-07-12 13:35:25

标签: python tensorflow tensorflow-lite

在将Tensorflow模型转换为TFLITE模型时,如何将布尔标量设置为输入?

我正在尝试将Facenet Tensor-flow模型转换为TFLITE格式,我可以将其用于Android应用程序。这个Facenet模型采用包含图像数据和布尔标量的张量作为输入:

feed_dict = { images_placeholder: images, phase_train_placeholder:False }

我遵循Tensorflow lite guide将模型转换为TFLITE格式。这是一个示例:

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")

with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
  open("converteds_model.tflite", "wb").write(tflite_model)

这是我的代码:

with tf.Graph().as_default():

        with tf.Session() as sess:

            # Load the model
            facenet.load_model(args.model)

            # Get input and output tensors
            images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
            embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
            phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
            #reshape as tflite does not accepts None dimension
            images_placeholder = tf.reshape(images_placeholder, [1,160,160,3])
            tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [images_placeholder, phase_train_placeholder], [embeddings])
            open("converteds_model.tflite", "wb").write(tflite_model)

程序给出一个错误信息(我想这是因为tflite尚不支持tf.bool):

File "/.../facenet/tensorflow1.8/lib/python3.4/site-packages/tensorflow/contrib/lite/python/convert.py", line 206, in toco_convert
    input_tensor.dtype))
ValueError: Tensors phase_train:0 not known type tf.bool

如果我将phase_train_placeholder强制转换为受支持的类型tf.int32,只是看它如何进行。然后它给出了另一个错误(我猜这是由于convert函数不接受标量而发生):

File "/.../facenet/tensorflow1.8/lib/python3.4/site-packages/tensorflow/contrib/lite/python/convert.py", line 217, in toco_convert
    input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
  File "/.../facenet/tensorflow1.8/lib/python3.4/site-packages/tensorflow/python/framework/tensor_shape.py", line 591, in __iter__
    raise ValueError("Cannot iterate over a shape with unknown rank.")
ValueError: Cannot iterate over a shape with unknown rank.

您能建议一个解决方法吗?我了解TFLite仍在开发中,但如果可能的话,我可以尝试贡献这一部分。非常感谢。

0 个答案:

没有答案