统计使用 TensorFlow Lite（Android）检测到的对象数量。以下是 TFLite 输出识别结果的代码：
/**
 * Image classifier backed by a TensorFlow Lite {@code Interpreter}.
 *
 * <p>Each call to {@link #recognizeImage(Bitmap)} packs the bitmap's RGB bytes into a direct
 * buffer, runs the interpreter on it, and returns at most {@code MAX_RESULTS} recognitions whose
 * confidence is above {@code THRESHOLD}, ordered from highest to lowest confidence.
 *
 * <p>NOTE(review): nothing visible in this class assigns {@code interpreter}, {@code inputSize}
 * or {@code labelList}; presumably a constructor or factory method was omitted. Confirm they are
 * initialised before {@code recognizeImage} is called.
 */
public class TensorFlowImageClassifier implements Classifier {

    /** Maximum number of recognitions returned per image. */
    private static final int MAX_RESULTS = 3;

    /** Number of images fed to the interpreter per call. */
    private static final int BATCH_SIZE = 1;

    /** Bytes written per pixel: one each for red, green and blue. */
    private static final int PIXEL_SIZE = 3;

    /** Recognitions must have a confidence strictly greater than this to be reported. */
    private static final float THRESHOLD = 0.1f;

    private Interpreter interpreter;

    /** Side length, in pixels, of the square image the input buffer is sized for. */
    private int inputSize;

    /** Labels, indexed by position in the interpreter's output row. */
    private List<String> labelList;

    /**
     * Classifies the given bitmap.
     *
     * @param bitmap image to classify; scaled to {@code inputSize x inputSize} if it has any
     *     other dimensions
     * @return up to {@code MAX_RESULTS} recognitions above {@code THRESHOLD}, best first
     */
    @Override
    public List<Recognition> recognizeImage(Bitmap bitmap) {
        ByteBuffer byteBuffer = convertBitmapToByteBuffer(bitmap);
        // One output row per batch entry, holding one unsigned-byte score per label.
        byte[][] result = new byte[BATCH_SIZE][labelList.size()];
        interpreter.run(byteBuffer, result);
        return getSortedResult(result);
    }

    /**
     * Packs the bitmap into a direct, native-ordered buffer of R, G, B bytes in row-major order.
     *
     * @param bitmap source image of any size
     * @return buffer of {@code BATCH_SIZE * inputSize * inputSize * PIXEL_SIZE} bytes
     */
    private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
        // The pixel array and the buffer are both sized for inputSize x inputSize. Reading a
        // bitmap at its own dimensions would overflow the array (larger bitmap) or pack rows at
        // the wrong stride (smaller bitmap), so anything else is scaled first.
        Bitmap scaled = bitmap;
        if (bitmap.getWidth() != inputSize || bitmap.getHeight() != inputSize) {
            scaled = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, false);
        }

        ByteBuffer byteBuffer =
                ByteBuffer.allocateDirect(BATCH_SIZE * inputSize * inputSize * PIXEL_SIZE);
        byteBuffer.order(ByteOrder.nativeOrder());

        int[] intValues = new int[inputSize * inputSize];
        scaled.getPixels(intValues, 0, inputSize, 0, 0, inputSize, inputSize);

        for (int val : intValues) {
            // Each int is a packed colour; the alpha byte (bits 24-31) is dropped.
            byteBuffer.put((byte) ((val >> 16) & 0xFF));
            byteBuffer.put((byte) ((val >> 8) & 0xFF));
            byteBuffer.put((byte) (val & 0xFF));
        }
        return byteBuffer;
    }

    /**
     * Converts raw interpreter output into recognitions sorted by descending confidence.
     *
     * @param labelProbArray interpreter output; row 0 holds one unsigned-byte score per label
     * @return at most {@code MAX_RESULTS} recognitions with confidence above {@code THRESHOLD}
     */
    private List<Recognition> getSortedResult(byte[][] labelProbArray) {
        // Reversed comparison so that poll() yields the highest confidence first.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<>(
                        MAX_RESULTS,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });

        for (int i = 0; i < labelList.size(); ++i) {
            // Scores are unsigned bytes; mask to 0..255 and normalise to 0..1.
            float confidence = (labelProbArray[0][i] & 0xff) / 255.0f;
            if (confidence > THRESHOLD) {
                pq.add(new Recognition(String.valueOf(i), labelList.get(i), confidence));
            }
        }

        final ArrayList<Recognition> recognitions = new ArrayList<>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        return recognitions;
    }
}