输入到tflite模型的数据类型不匹配

时间:2019-03-30 16:12:17

标签: android tensorflow kotlin

我使用python创建了以下SimpleRNN模型。

python
#Omit
Xtest = np.zeros((1, 10, 1102))
#Omit
pred = model.predict(Xtest, verbose=0)[0]

如您所见,我使用整数的三维数组作为模型的输入。

然后,我将此Android模型移植为.tflite。 在下面的代码中,名为tfliteModel的部分相对应。

kotlin
Interpreter(tfliteModel!!).use { interpreter ->
    val input_onehot = Array(1) { Array(10) {Array<Int>(1102) {0} } }
    val output = Array(1) {Array<Float>(1102) { 0F } }
    //some operation like making it a one hot vector
    interpreter.run(input_onehot, output)
}

但是Android Studio抛出了这样的错误:

Caused by: java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of [[[Ljava.lang.Integer;

为什么会出现此错误? 如何将整数数组加载到模型中?

我使用this site作为参考。 但是,这是指图像,而不是NLP。

1 个答案:

答案 0 :(得分:0)

我不知道为什么,但是我通过读取float而不是int来解决它。 我使用了np.int32 ・ ・ ・ 也许numpy的int32和kotlin的Int32可能不兼容。