Tensorflow Lite:无法在TensorFlowLite缓冲区和ByteBuffer之间转换

时间:2019-03-04 03:20:24

标签: tensorflow tensorflow-lite

我尝试将自定义模型迁移到Android平台。张量流版本为1.12。我使用了推荐的命令行,如下所示:

tflite_convert \
  --output_file=test.tflite \
  --graph_def_file=./models/test_model.pb \
  --input_arrays=input_image \
  --output_arrays=generated_image

将.pb文件转换为tflite格式。

我已经在tensorboard中检查了.pb文件的输入张量形状:

dtype
{"type":"DT_FLOAT"}
shape
{"shape":{"dim":[{"size":474},{"size":712},{"size":3}]}}

然后,我在Android上部署tflite文件,并分配计划将模型馈入的输入ByteBuffer为:

imgData = ByteBuffer.allocateDirect(
          4 * 1 * 712 * 474 * 3);

当我在Android设备上运行模型时,应用程序崩溃,然后logcat的打印结果如下:

2019-03-04 10:31:46.822 17884-17884/android.example.com.tflitecamerademo E/AndroidRuntime: FATAL EXCEPTION: main
    Process: android.example.com.tflitecamerademo, PID: 17884
    java.lang.RuntimeException: Unable to start activity ComponentInfo{android.example.com.tflitecamerademo/com.example.android.tflitecamerademo.CameraActivity}: java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite buffer with 786432 bytes and a ByteBuffer with 4049856 bytes.

这太奇怪了,因为分配的ByteBuffer恰好是4 * 3 * 474 * 712的乘积,而tensorflow lite缓冲区不是474或712的倍数。我不知道为什么tflite模型的形状错误。

如果有人能提供解决方案,请先感谢。

3 个答案:

答案 0 :(得分:2)

您可以可视化TFLite模型来调试实际将哪些缓冲区大小分配给输入张量。

TensorFlow Lite模型可以使用 visualize.py 脚本。

如果输入张量的缓冲区大小不是您期望的大小,则转换(或提供给tflite_convert的参数)可能存在错误

答案 1 :(得分:0)

由于其他原因,我已将图像尺寸从模型创建过程中的标准224更改为299,因此我只在Android Studio项目中搜索了224,并更新了两个final参考ImageClassifier.java至299,然后我又重新营业。

答案 2 :(得分:0)

大家好, 我昨天也有类似的问题。我想提一下对我有用的解决方案。

似乎TSLite仅支持精确的正方形位图输入 喜欢 尺寸256 * 256检测正常 大小256 * 255检测不起作用引发异常

以及支持的最大尺寸 257 * 257应该是任何位图输入的最大宽度和高度

这是裁剪和调整位图大小的示例代码

private var MODEL_HEIGHT = 257
private var MODEL_WIDTH = 257

裁剪位图

val croppedBitmap = cropBitmap(bitmap)

为模型输入创建位图的缩放版本

val scaledBitmap = Bitmap.createScaledBitmap(croppedBitmap, MODEL_WIDTH, MODEL_HEIGHT, true)

https://github.com/tensorflow/examples/blob/master/lite/examples/posenet/android/app/src/main/java/org/tensorflow/lite/examples/posenet/PosenetActivity.kt#L578

裁剪位图以保持模型输入的长宽比。

private fun cropBitmap(bitmap: Bitmap): Bitmap {
val bitmapRatio = bitmap.height.toFloat() / bitmap.width
val modelInputRatio = MODEL_HEIGHT.toFloat() / MODEL_WIDTH
var croppedBitmap = bitmap

// Acceptable difference between the modelInputRatio and bitmapRatio to skip cropping.
val maxDifference = 1e-5

// Checks if the bitmap has similar aspect ratio as the required model input.
when {
  abs(modelInputRatio - bitmapRatio) < maxDifference -> return croppedBitmap
  modelInputRatio < bitmapRatio -> {
    // New image is taller so we are height constrained.
    val cropHeight = bitmap.height - (bitmap.width.toFloat() / modelInputRatio)
    croppedBitmap = Bitmap.createBitmap(
      bitmap,
      0,
      (cropHeight / 2).toInt(),
      bitmap.width,
      (bitmap.height - cropHeight).toInt()
    )
  }
  else -> {
    val cropWidth = bitmap.width - (bitmap.height.toFloat() * modelInputRatio)
    croppedBitmap = Bitmap.createBitmap(
      bitmap,
      (cropWidth / 2).toInt(),
      0,
      (bitmap.width - cropWidth).toInt(),
      bitmap.height
    )
  }
}
return croppedBitmap
}

https://github.com/tensorflow/examples/blob/master/lite/examples/posenet/android/app/src/main/java/org/tensorflow/lite/examples/posenet/PosenetActivity.kt#L451 谢谢并恭祝安康 潘卡(Pankaj)