转换后,TFLite模型中缺少图层

时间:2020-07-23 10:49:51

标签: python c++ tensorflow tensorflow-lite

所以我使用tensorflow和python训练了一个模型,现在我正尝试在C ++程序中使用它。我使用此代码将模型转换为tflite(转换过程中没有任何错误):

model.load_weights('training_weights.h5', by_name=True)
model.save('saved_model/model')

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model/model")
tflite_model = converter.convert()
with tf.io.gfile.GFile('model.tflite', 'wb') as f:
    f.write(tflite_model)

然后我用C ++加载模型,并尝试使用它,但是输出与我的网络输出不匹配。使用python,网络的最后一层是:

X = Lambda(lambda x: K.expand_dims(x, axis=2), name='deconv_expand_dim')(input_tensor)
X = Conv2DTranspose(filters, (kernel_size, 1), strides=(strides, 1), padding=padding, 
                    activation=activation, kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer, name='deconv')(X)
X = Lambda(lambda x: K.squeeze(x, axis=2), name='deconv_reduce_dim')(X)

以及在C ++中加载的模型的最后一层是:

1227 model/deconv_expand_dim/ExpandDims;StatefulPartitionedCall/model/deconv_expand_dim/ExpandDims
1228 model/deconv/Shape;StatefulPartitionedCall/model/deconv/Shape
1229 model/deconv/strided_slice;StatefulPartitionedCall/model/deconv/strided_slice1
1230 model/deconv/strided_slice_1;StatefulPartitionedCall/model/deconv/strided_slice_1
1231 model/deconv/strided_slice_2;StatefulPartitionedCall/model/deconv/strided_slice_22
1232 model/deconv/stack;StatefulPartitionedCall/model/deconv/stack
1233 model/deconv/conv2d_transpose;StatefulPartitionedCall/model/deconv/conv2d_transpose1
1234 model/deconv/BiasAdd;StatefulPartitionedCall/model/deconv/BiasAdd
1235 Identity

我只是对解释器-> tensors_size()进行了for循环以列出层。问题:

  • interpreter-> tensors_size()返回一个较大的数字,该数字因一项测试而异于另一项测试(大约1300)。
  • interpreter-> outputs()[0] 返回Identity的索引(1235)。
  • 在C ++中,我有 deconv_expand_dim deconv 层,但是我没有 deconv_reduce_dim 层。
  • interpreter-> tensor(outputIndex)-> dims-> size 等于0,这是有问题的,因为这是网络的输出。我在转换过程中错过了一步吗?如何获得有效的输出?

非常感谢您的帮助。

1 个答案:

答案 0 :(得分:0)

我想我找到了它不起作用的原因。如您在最后一层中看到的,我使用运算符:

K.squeeze(x, axis=2)

似乎尚不支持指定的轴(“ tf.squeeze-只要不提供轴。”,最后访问时间:2020年7月28日,link)。

这就是为什么我在C ++中没有与Python中相同的输出的原因,Tensorflow Lite目前不支持我的网络。