Tensorflow-Lite Android演示版与其提供的原始模型配合使用:mobilenet_quant_v1_224.tflite。请参阅:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
他们还在这里提供其他预训练的精简模型:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md
但是,我从上面的链接中下载了一些较小的模型,例如mobilenet_v1_0.25_224.tflite,只需更改MODEL_PATH = "mobilenet_v1_0.25_224.tflite";
中的ImageClassifier.java
,就可以在演示应用中将此模型替换为原始模型。 {1}}。该应用程序崩溃:
12-11 12:52:34.222 17713-17729 /? E / AndroidRuntime:致命异常: CameraBackground 处理:android.example.com.tflitecamerademo,PID:17713 java.lang.IllegalArgumentException:无法获取输入维度。 第0个输入应该有602112个字节,但是找到150528个字节。 at org.tensorflow.lite.NativeInterpreterWrapper.getInputDims(Native 方法) at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:82) at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:112) 在org.tensorflow.lite.Interpreter.run(Interpreter.java:93) at com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108) at com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) 在com.example.android.tflitecamerademo.Camera2BasicFragment.access $ 900(Camera2BasicFragment.java:69) 在com.example.android.tflitecamerademo.Camera2BasicFragment $ 5.run(Camera2BasicFragment.java:558) 在android.os.Handler.handleCallback(Handler.java:751) 在android.os.Handler.dispatchMessage(Handler.java:95) 在android.os.Looper.loop(Looper.java:154) 在android.os.HandlerThread.run(HandlerThread.java:61)
原因似乎是模型所需的输入尺寸是图像尺寸的四倍。所以我将DIM_BATCH_SIZE = 1
修改为DIM_BATCH_SIZE = 4
。现在错误是:
致命异常:CameraBackground 处理:android.example.com.tflitecamerademo,PID:18241 java.lang.IllegalArgumentException:无法转换TensorFlowLite 将类型为FLOAT32的张量转换为类型为[[B的Java对象] 兼容TensorFlowLite类型UINT8) 在org.tensorflow.lite.Tensor.copyTo(Tensor.java:36) at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:122) 在org.tensorflow.lite.Interpreter.run(Interpreter.java:93) at com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108) at com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) 在com.example.android.tflitecamerademo.Camera2BasicFragment.access $ 900(Camera2BasicFragment.java:69) 在com.example.android.tflitecamerademo.Camera2BasicFragment $ 5.run(Camera2BasicFragment.java:558) 在android.os.Handler.handleCallback(Handler.java:751) 在android.os.Handler.dispatchMessage(Handler.java:95) 在android.os.Looper.loop(Looper.java:154) 在android.os.HandlerThread.run(HandlerThread.java:61)
我的问题是如何使用简化的MobileNet tflite模型与TF-lite Android Demo一起使用。
(我实际尝试了其他的东西,比如使用提供的工具将TF冻结图转换为TF-lite模型,甚至使用与https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md中完全相同的示例代码,但转换后的tflite模型仍无法工作Android演示。)
答案 0 :(得分:4)
Tensorflow-Lite Android演示文稿中包含的ImageClassifier.java需要量化模型。截至目前,只有一种Mobilenets模型以量化形式提供: Mobilenet 1.0 224 Quant 。
要使用其他浮点模型,请从Tensorflow for Poets TF-Lite演示源交换ImageClassifier.java。这是为 float 模型编写的。 https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tflite/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
做一个差异,你会发现实施中存在几个重要的差异。
要考虑的另一个选择是将浮点模型转换为使用TOCO量化: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
答案 1 :(得分:1)
我也得到了与幼苗相同的错误。 我为Mobilenet Float模型创建了一个新的Image分类器包装器。 它现在工作正常。您可以直接在图像分类器演示中添加此类,并使用它在Camera2BasicFragment中创建分类器
classifier = new ImageClassifierFloatMobileNet(getActivity());
下面是Mobilenet Float模型的Image分类器类包装器
/**
* This classifier works with the Float MobileNet model.
*/
public class ImageClassifierFloatMobileNet extends ImageClassifier {
/**
* An array to hold inference results, to be feed into Tensorflow Lite as outputs.
* This isn't part of the super class, because we need a primitive array here.
*/
private float[][] labelProbArray = null;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128.0f;
/**
* Initializes an {@code ImageClassifier}.
*
* @param activity
*/
public ImageClassifierFloatMobileNet(Activity activity) throws IOException {
super(activity);
labelProbArray = new float[1][getNumLabels()];
}
@Override
protected String getModelPath() {
// you can download this file from
// https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
// return "mobilenet_quant_v1_224.tflite";
return "retrained.tflite";
}
@Override
protected String getLabelPath() {
// return "labels_mobilenet_quant_v1_224.txt";
return "retrained_labels.txt";
}
@Override
public int getImageSizeX() {
return 224;
}
@Override
public int getImageSizeY() {
return 224;
}
@Override
protected int getNumBytesPerChannel() {
// the Float model uses a 4 bytes
return 4;
}
@Override
protected void addPixelValue(int val) {
imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
}
@Override
protected float getProbability(int labelIndex) {
return labelProbArray[0][labelIndex];
}
@Override
protected void setProbability(int labelIndex, Number value) {
labelProbArray[0][labelIndex] = value.byteValue();
}
@Override
protected float getNormalizedProbability(int labelIndex) {
return labelProbArray[0][labelIndex];
}
@Override
protected void runInference() {
tflite.run(imgData, labelProbArray);
}
}