无法在TFLite张量之间复制-形状不合法

时间:2018-11-06 15:06:54

标签: java android tensorflow tensorflow-lite

我目前正在制作自己的模型,并且在tensorflow-for-poets-2演示中都可以正常工作。我在不同的文件夹中训练了多张图片,应用程序识别出了它。

现在,我想在对象周围显示一个边框。我找到了一个示例here

我的问题是,当我添加自己的tflite模型时,我的应用返回以下错误:

E/AndroidRuntime: FATAL EXCEPTION: inference
Process: org.tensorflow.lite.demo, PID: 3495
java.lang.IllegalArgumentException: Cannot copy between a TensorFlowLite tensor with shape [1, 6] and a Java object with shape [1, 10, 4].
    at org.tensorflow.lite.Tensor.throwExceptionIfTypeIsIncompatible(Tensor.java:240)
    at org.tensorflow.lite.Tensor.copyTo(Tensor.java:116)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:157)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:229)
    at org.tensorflow.demo.TFLiteObjectDetectionAPIModel.recognizeImage(TFLiteObjectDetectionAPIModel.java:194)
    at org.tensorflow.demo.DetectorActivity$3.run(DetectorActivity.java:247)
    at android.os.Handler.handleCallback(Handler.java:789)
    at android.os.Handler.dispatchMessage(Handler.java:98)
    at android.os.Looper.loop(Looper.java:164)
    at android.os.HandlerThread.run(HandlerThread.java:65)

我如何训练他们:

python3 scripts/retrain.py \
  --bottleneck_dir=bottlenecks \
  --how_many_training_steps=500 \
  --model_dir=inception \
  --output_graph=tf_files/retrained_graph.pb \
  --output_labels=tf_files/retrained_labels.txt \
  --image_dir=tf_files/ \
  --architecture mobilenet_1.0_224

正在生成:

toco \
  --input_format=TENSORFLOW_GRAPHDEF \
  --input_file=tf_files/retrained_graph.pb \
  --output_format=TFLITE \
  --output_file=tf_files/optimized_graph.lite \
  --inference_type=FLOAT \
  --inference_input_type=FLOAT \
  --input_arrays=input \
  --output_arrays=final_result \
  --input_shapes=1,224,224,3\
  --mean_values=128 \
  --std_values=128 \
  --default_ranges_min=0 \
  --default_ranges_max=6

DetectorActivity.java

// Configuration values for the prepackaged SSD model.
private static final int TF_OD_API_INPUT_SIZE = 224; // 300
private static final boolean TF_OD_API_IS_QUANTIZED = false; // true
private static final String TF_OD_API_MODEL_FILE = "optimized_graph.lite"; //detect.tflite
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/retrained_labels.txt";

2 个答案:

答案 0 :(得分:0)

您当前的用例是什么?与Tensorflow诗人一样吗?该错误很明显是您的tensorflow模型输出形状与应用程序不匹配。该应用程序的用例可能与您的用例不同。

答案 1 :(得分:0)

我来了。 只是因为您准备使用模型

tfLite.run(imgData,outputScores);

私有静态最终整数NUM_DETECTIONS = 6;

它在我身边有效