在将Tensorflow
模型转换为TFLITE
模型时,如何将布尔标量设置为输入?
我正在尝试将Facenet Tensor-flow模型转换为TFLITE
格式,我可以将其用于Android应用程序。这个Facenet模型采用包含图像数据和布尔标量的张量作为输入:
feed_dict = { images_placeholder: images, phase_train_placeholder:False }
我遵循Tensorflow lite guide将模型转换为TFLITE格式。这是一个示例:
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
open("converteds_model.tflite", "wb").write(tflite_model)
这是我的代码:
with tf.Graph().as_default():
with tf.Session() as sess:
# Load the model
facenet.load_model(args.model)
# Get input and output tensors
images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
#reshape as tflite does not accepts None dimension
images_placeholder = tf.reshape(images_placeholder, [1,160,160,3])
tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [images_placeholder, phase_train_placeholder], [embeddings])
open("converteds_model.tflite", "wb").write(tflite_model)
程序给出一个错误信息(我想这是因为tflite
尚不支持tf.bool
):
File "/.../facenet/tensorflow1.8/lib/python3.4/site-packages/tensorflow/contrib/lite/python/convert.py", line 206, in toco_convert
input_tensor.dtype))
ValueError: Tensors phase_train:0 not known type tf.bool
如果我将phase_train_placeholder
强制转换为受支持的类型tf.int32
,只是看它如何进行。然后它给出了另一个错误(我猜这是由于convert函数不接受标量而发生):
File "/.../facenet/tensorflow1.8/lib/python3.4/site-packages/tensorflow/contrib/lite/python/convert.py", line 217, in toco_convert
input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
File "/.../facenet/tensorflow1.8/lib/python3.4/site-packages/tensorflow/python/framework/tensor_shape.py", line 591, in __iter__
raise ValueError("Cannot iterate over a shape with unknown rank.")
ValueError: Cannot iterate over a shape with unknown rank.
您能建议一个解决方法吗?我了解TFLite仍在开发中,但如果可能的话,我可以尝试贡献这一部分。非常感谢。