TFLite中缺少广播的解决方法

时间:2019-09-23 20:11:46

标签: tensorflow

我想运行一个TFLite模型,该模型要求我产生3d输出(示例代码是产生错误的最小示例)。是否有一个等效于collector_nd的张量流不会将尺寸减小一倍?

我尝试浏览文档中我能想到的相关功能,但是找不到一个好的选择。

import tensorflow.compat.v1 as tf
import numpy as np

tf.disable_v2_behavior()
initial_input = tf.placeholder(dtype=tf.float32, shape=(None,5,1024))
cap_i = tf.gather_nd(initial_input, [[0,1]]) #[0,2],[0,3],[0,4],[0,5]
cap_i_broadcast = tf.broadcast_to(cap_i, [1,5,1024])
cap_iT = tf.transpose(cap_i_broadcast, perm=[0,2,1])

sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.io.write_graph(sess.graph_def, '', 'train.pbtxt')
converter = tf.lite.TFLiteConverter.from_session(sess, [initial_input], [cap_iT])
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('converted_model.tflite', "wb").write(tflite_model)
sess.close()

标准TensorFlow Lite运行时不支持模型中的某些运算符,并且TensorFlow无法识别。如果您有针对他们的自定义实现,则可以使用--allow_custom_ops或通过在调用tf.lite.TFLiteConverter()时设置allow_custom_ops = True来禁用此错误。以下是您正在使用的内置运算符的列表:GATHER_ND,TRANSPOSE。这是您需要自定义实现的运算符的列表:BroadcastTo。

1 个答案:

答案 0 :(得分:0)

下面的代码提供了一种解决方案,该方法使用具有降低维数的分步切片,然后重新整形以获取正确的维数。

import tensorflow.compat.v1 as tf
import numpy as np

tf.disable_v2_behavior()
initial_input = tf.placeholder(dtype=tf.float32, shape=(None,5,1024))
cap_i = tf.strided_slice(initial_input, [0,0,0], [0,5,1024], [1,1,1], 
shrink_axis_mask=1)
cap_i_reshaped =tf.reshape(cap_i,[1,5,1024])
cap_iT = tf.transpose(cap_i_reshaped, perm=[0,2,1])

sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.io.write_graph(sess.graph_def, '', 'train.pbtxt')
converter = tf.lite.TFLiteConverter.from_session(sess, [initial_input], 
[cap_iT])
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, 
tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('converted_model.tflite', "wb").write(tflite_model)
sess.close()

TFLite以前支持切片,但只有strided_slice。