我正在尝试用Java实现Yolo检测器(不是Android,而是桌面 - Windows / Ubuntu)
Android已有Yolo探测器:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android
我从该项目复制了一些java类,将它们添加到IntelliJ IDEA并编辑它们
我甚至复制并编辑了TensorFlowInferenceInterface.java
来自jar文件(tensorflow库 - libandroid_tensorflow_inference_java.jar
)for Android
我几乎成功地让它发挥作用。
结果
控制台输出(类标题,置信度,x,y,宽度,高度):
汽车0.8836523,148 166 270 267
car 0.51286024,147 174 268 274
car 0.05002968,174 164 275 262
所以似乎正确检测到汽车,确定正确 x , y 但宽度和高度
有问题可能出现什么问题?
这是我项目的完整代码
主要
public class Main implements Classifier {
private static final int BLOCK_SIZE = 32;
private static final int MAX_RESULTS = 3;
private static final int NUM_CLASSES = 20;
private static final int NUM_BOXES_PER_BLOCK = 5;
private static final int INPUT_SIZE = 416;
private static final String inputName = "input";
private static final String outputName = "output";
// Pre-allocated buffers.
private static int[] intValues;
private static float[] floatValues;
private static String[] outputNames;
// yolo 2
private static final double[] ANCHORS = { 1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071 };
// tiny yolo
//private static final double[] ANCHORS = { 1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52 };
private static final String[] LABELS = {
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor"
};
private static TensorFlowInferenceInterface inferenceInterface;
public static void main(String[] args) {
//String modelDir = "/home/user/JavaProjects/TensorFlowJavaProject"; // Ubuntu
String modelAndTestImagesDir = "D:\\JavaProjects\\TensorFlowJavaProject"; // Windows
String imageFile = modelAndTestImagesDir + File.separator + "0.png"; // 416x416 test image
outputNames = outputName.split(",");
floatValues = new float[INPUT_SIZE * INPUT_SIZE * 3];
// yolo 2 voc
inferenceInterface = new TensorFlowInferenceInterface(Paths.get(modelAndTestImagesDir, "yolo-voc.pb"));
// tiny yolo voc
//inferenceInterface = new TensorFlowInferenceInterface(Paths.get(modelAndTestImagesDir, "graph-tiny-yolo-voc.pb"));
BufferedImage img;
try {
img = ImageIO.read(new File(imageFile));
BufferedImage convertedImg = new BufferedImage(img.getWidth(), img.getHeight(), BufferedImage.TYPE_INT_RGB);
convertedImg.getGraphics().drawImage(img, 0, 0, null);
intValues = ((DataBufferInt) convertedImg.getRaster().getDataBuffer()).getData() ;
List<Classifier.Recognition> recognitions = recognizeImage();
System.out.println("Result length " + recognitions.size());
Graphics2D graphics = convertedImg.createGraphics();
for (Recognition recognition : recognitions) {
RectF rectF = recognition.getLocation();
System.out.println(recognition.getTitle() + " " + recognition.getConfidence() + ", " +
(int) rectF.x + " " + (int) rectF.y + " " + (int) rectF.width + " " + ((int) rectF.height));
Stroke stroke = graphics.getStroke();
graphics.setStroke(new BasicStroke(3));
graphics.setColor(Color.green);
graphics.drawRoundRect((int) rectF.x, (int) rectF.y, (int) rectF.width, (int) rectF.height, 5, 5);
graphics.setStroke(stroke);
}
graphics.dispose();
ImageIcon icon=new ImageIcon(convertedImg);
JFrame frame=new JFrame();
frame.setLayout(new FlowLayout());
frame.setSize(convertedImg.getWidth(),convertedImg.getHeight());
JLabel lbl=new JLabel();
frame.setTitle("Java (Win/Ubuntu), Tensorflow & Yolo");
lbl.setIcon(icon);
frame.add(lbl);
frame.setVisible(true);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
} catch (IOException e) {
e.printStackTrace();
}
}
private static List<Classifier.Recognition> recognizeImage() {
for (int i = 0; i < intValues.length; ++i) {
floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
}
inferenceInterface.feed(inputName, floatValues, 1, INPUT_SIZE, INPUT_SIZE, 3);
inferenceInterface.run(outputNames, false);
final int gridWidth = INPUT_SIZE / BLOCK_SIZE;
final int gridHeight = INPUT_SIZE / BLOCK_SIZE;
final float[] output = new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
inferenceInterface.fetch(outputNames[0], output);
// Find the best detections.
final PriorityQueue<Classifier.Recognition> pq =
new PriorityQueue<>(
1,
new Comparator<Classifier.Recognition>() {
@Override
public int compare(final Classifier.Recognition lhs, final Classifier.Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
for (int y = 0; y < gridHeight; ++y) {
for (int x = 0; x < gridWidth; ++x) {
for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
final int offset =
(gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
+ (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
+ (NUM_CLASSES + 5) * b;
final float xPos = (x + expit(output[offset + 0])) * BLOCK_SIZE;
final float yPos = (y + expit(output[offset + 1])) * BLOCK_SIZE;
final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * BLOCK_SIZE;
final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * BLOCK_SIZE;
final RectF rect =
new RectF(
Math.max(0, xPos - w / 2),
Math.max(0, yPos - h / 2),
Math.min(INPUT_SIZE - 1, xPos + w / 2),
Math.min(INPUT_SIZE - 1, yPos + h / 2));
final float confidence = expit(output[offset + 4]);
int detectedClass = -1;
float maxClass = 0;
final float[] classes = new float[NUM_CLASSES];
for (int c = 0; c < NUM_CLASSES; ++c) {
classes[c] = output[offset + 5 + c];
}
softmax(classes);
for (int c = 0; c < NUM_CLASSES; ++c) {
if (classes[c] > maxClass) {
detectedClass = c;
maxClass = classes[c];
}
}
final float confidenceInClass = maxClass * confidence;
if (confidenceInClass > 0.01) {
pq.add(new Classifier.Recognition(detectedClass, LABELS[detectedClass], confidenceInClass, rect));
}
}
}
}
final ArrayList<Classifier.Recognition> recognitions = new ArrayList<>();
for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
recognitions.add(pq.poll());
}
return recognitions;
}
private static float expit(final float x) {
return (float) (1. / (1. + Math.exp(-x)));
}
private static void softmax(final float[] vals) {
float max = Float.NEGATIVE_INFINITY;
for (final float val : vals) {
max = Math.max(max, val);
}
float sum = 0.0f;
for (int i = 0; i < vals.length; ++i) {
vals[i] = (float) Math.exp(vals[i] - max);
sum += vals[i];
}
for (int i = 0; i < vals.length; ++i) {
vals[i] = vals[i] / sum;
}
}
public void close() {
inferenceInterface.close();
}
}
TensorFlowInferenceInterface
public class TensorFlowInferenceInterface {
private static final String TAG = "TensorFlowInferenceInterface";
private final Graph g;
private final Session sess;
private Runner runner;
private List<String> feedNames = new ArrayList();
private List<Tensor> feedTensors = new ArrayList();
private List<String> fetchNames = new ArrayList();
private List<Tensor> fetchTensors = new ArrayList();
private RunStats runStats;
public TensorFlowInferenceInterface(Path path) {
this.prepareNativeRuntime();
this.g = new Graph();
this.sess = new Session(this.g);
this.runner = this.sess.runner();
try {
this.loadGraph(readAllBytesOrExit(path), this.g);
} catch (IOException e) {
e.printStackTrace();
}
}
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
public void run(String[] var1) {
this.run(var1, false);
}
public void run(String[] var1, boolean var2) {
this.closeFetches();
String[] var3 = var1;
int var4 = var1.length;
for (int var5 = 0; var5 < var4; ++var5) {
String var6 = var3[var5];
this.fetchNames.add(var6);
TensorFlowInferenceInterface.TensorId var7 = TensorFlowInferenceInterface.TensorId.parse(var6);
this.runner.fetch(var7.name, var7.outputIndex);
}
try {
if (var2) {
Run var13 = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
this.fetchTensors = var13.outputs;
if (this.runStats == null) {
this.runStats = new RunStats();
}
this.runStats.add(var13.metadata);
} else {
this.fetchTensors = this.runner.run();
}
} catch (RuntimeException var11) {
throw var11;
} finally {
this.closeFeeds();
this.runner = this.sess.runner();
}
}
public Graph graph() {
return this.g;
}
public Operation graphOperation(String var1) {
Operation var2 = this.g.operation(var1);
if (var2 == null) {
throw new RuntimeException("Node '" + var1 + "' does not exist in model '");
} else {
return var2;
}
}
public String getStatString() {
return this.runStats == null ? "" : this.runStats.summary();
}
public void close() {
this.closeFeeds();
this.closeFetches();
this.sess.close();
this.g.close();
if (this.runStats != null) {
this.runStats.close();
}
this.runStats = null;
}
protected void finalize() throws Throwable {
try {
this.close();
} finally {
super.finalize();
}
}
public void feed(String var1, float[] var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, FloatBuffer.wrap(var2)));
}
public void fetch(String var1, float[] var2) {
this.fetch(var1, FloatBuffer.wrap(var2));
}
public void fetch(String var1, FloatBuffer var2) {
this.getTensor(var1).writeTo(var2);
}
private void prepareNativeRuntime() {
System.out.println("TensorFlowInferenceInterface Checking to see if TensorFlow native methods are already loaded");
try {
new RunStats();
System.out.println("TensorFlowInferenceInterface TensorFlow native methods already loaded");
} catch (UnsatisfiedLinkError var4) {
System.out.println("TensorFlowInferenceInterface TensorFlow native methods not found, attempting to load via tensorflow_inference");
}
}
private void loadGraph(byte[] var1, Graph var2) throws IOException {
try {
var2.importGraphDef(var1);
} catch (IllegalArgumentException var7) {
throw new IOException("Not a valid TensorFlow Graph serialization: " + var7.getMessage());
}
}
private void addFeed(String var1, Tensor var2) {
TensorFlowInferenceInterface.TensorId var3 = TensorFlowInferenceInterface.TensorId.parse(var1);
this.runner.feed(var3.name, var3.outputIndex, var2);
this.feedNames.add(var1);
this.feedTensors.add(var2);
}
private Tensor getTensor(String var1) {
int var2 = 0;
for (Iterator var3 = this.fetchNames.iterator(); var3.hasNext(); ++var2) {
String var4 = (String) var3.next();
if (var4.equals(var1)) {
return this.fetchTensors.get(var2);
}
}
throw new RuntimeException("Node '" + var1 + "' was not provided to run(), so it cannot be read");
}
private void closeFeeds() {
Iterator var1 = this.feedTensors.iterator();
while (var1.hasNext()) {
Tensor var2 = (Tensor) var1.next();
var2.close();
}
this.feedTensors.clear();
this.feedNames.clear();
}
private void closeFetches() {
Iterator var1 = this.fetchTensors.iterator();
while (var1.hasNext()) {
Tensor var2 = (Tensor) var1.next();
var2.close();
}
this.fetchTensors.clear();
this.fetchNames.clear();
}
private static class TensorId {
String name;
int outputIndex;
private TensorId() {
}
public static TensorFlowInferenceInterface.TensorId parse(String var0) {
TensorFlowInferenceInterface.TensorId var1 = new TensorFlowInferenceInterface.TensorId();
int var2 = var0.lastIndexOf(58);
if (var2 < 0) {
var1.outputIndex = 0;
var1.name = var0;
return var1;
} else {
try {
var1.outputIndex = Integer.parseInt(var0.substring(var2 + 1));
var1.name = var0.substring(0, var2);
} catch (NumberFormatException var4) {
var1.outputIndex = 0;
var1.name = var0;
}
return var1;
}
}
}
}
分类
public interface Classifier {
public class Recognition {
private final int id;
private final String title;
private final Float confidence;
private RectF location;
public Recognition(
final int id, final String title, final Float confidence, final RectF location) {
this.id = id;
this.title = title;
this.confidence = confidence;
this.location = location;
}
public int getId() {
return id;
}
public String getTitle() {
return title;
}
public Float getConfidence() {
return confidence;
}
public RectF getLocation() {
return new RectF(location);
}
public void setLocation(RectF location) {
this.location = location;
}
}
void close();
}
RUNSTATS
public class RunStats implements AutoCloseable {
private long nativeHandle = allocate();
private static byte[] fullTraceRunOptions = new byte[]{8, 3};
public static byte[] runOptions() {
return fullTraceRunOptions;
}
public RunStats() {
}
public void close() {
if(this.nativeHandle != 0L) {
delete(this.nativeHandle);
}
this.nativeHandle = 0L;
}
public synchronized void add(byte[] var1) {
add(this.nativeHandle, var1);
}
public synchronized String summary() {
return summary(this.nativeHandle);
}
private static native long allocate();
private static native void delete(long var0);
private static native void add(long var0, byte[] var2);
private static native String summary(long var0);
}
RectF
public class RectF {
public float getX() {
return x;
}
public void setX(float x) {
this.x = x;
}
public float getY() {
return y;
}
public void setY(float y) {
this.y = y;
}
public float getWidth() {
return width;
}
public void setWidth(float width) {
this.width = width;
}
public float getHeight() {
return height;
}
public void setHeight(float height) {
this.height = height;
}
public float x = 0f;
public float y = 0f;
public float width = 0f;
public float height = 0f;
RectF(RectF rectF) {
this.x = rectF.x;
this.y = rectF.y;
this.width = rectF.width;
this.height = rectF.height;
}
RectF(float x, float y, float w, float h) {
this.x = x;
this.y = y;
this.width = w;
this.height = h;
}
}
答案 0 :(得分:3)
已解决(我混淆了x,y,width,height
和left,top,right,bottom
)
这是更新的RectF
:
public class RectF {
public float left;
public float top;
public float right;
public float bottom;
public RectF() {}
public RectF(float left, float top, float right, float bottom) {
this.left = left;
this.top = top;
this.right = right;
this.bottom = bottom;
}
public RectF(RectF r) {
if (r == null) {
left = top = right = bottom = 0.0f;
} else {
left = r.left;
top = r.top;
right = r.right;
bottom = r.bottom;
}
}
public String toString() {
return "RectF(" + left + ", " + top + ", "
+ right + ", " + bottom + ")";
}
public final float width() {
return right - left;
}
public final float height() {
return bottom - top;
}
public final float centerX() {
return (left + right) * 0.5f;
}
public final float centerY() {
return (top + bottom) * 0.5f;
}
}
然后
graphics.drawRoundRect((int) rectF.left, (int) rectF.top, (int) rectF.width(), (int) rectF.height(), 5, 5);
P.S。对于TensorFlow 1.4.0,这里是一个更新的TensorFlowInferenceInterface
类:
public class TensorFlowInferenceInterface {
private final Graph g;
private final Session sess;
private Runner runner;
private List<String> feedNames = new ArrayList();
private List<Tensor<?>> feedTensors = new ArrayList();
private List<String> fetchNames = new ArrayList();
private List<Tensor<?>> fetchTensors = new ArrayList();
private RunStats runStats;
public TensorFlowInferenceInterface(Path path) {
this.prepareNativeRuntime();
this.g = new Graph();
this.sess = new Session(this.g);
this.runner = this.sess.runner();
try {
this.loadGraph(readAllBytesOrExit(path), this.g);
} catch (IOException e) {
e.printStackTrace();
}
}
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
public void run(String[] var1) {
this.run(var1, false);
}
public void run(String[] var1, boolean var2) {
this.closeFetches();
String[] var3 = var1;
int var4 = var1.length;
for(int var5 = 0; var5 < var4; ++var5) {
String var6 = var3[var5];
this.fetchNames.add(var6);
TensorFlowInferenceInterface.TensorId var7 = TensorFlowInferenceInterface.TensorId.parse(var6);
this.runner.fetch(var7.name, var7.outputIndex);
}
try {
if(var2) {
Run var13 = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
this.fetchTensors = var13.outputs;
if(this.runStats == null) {
this.runStats = new RunStats();
}
this.runStats.add(var13.metadata);
} else {
this.fetchTensors = this.runner.run();
}
} catch (RuntimeException var11) {
System.out.println("Failed to run TensorFlow inference with inputs:["
+ String.join(", ", this.feedNames)
+ "], outputs:[" + String.join(", ", this.fetchNames) + "]");
throw var11;
} finally {
this.closeFeeds();
this.runner = this.sess.runner();
}
}
public Graph graph() {
return this.g;
}
public Operation graphOperation(String var1) {
Operation var2 = this.g.operation(var1);
if(var2 == null) {
throw new RuntimeException("Node '" + var1 + "' does not exist in model '");
} else {
return var2;
}
}
public String getStatString() {
return this.runStats == null?"":this.runStats.summary();
}
public void close() {
this.closeFeeds();
this.closeFetches();
this.sess.close();
this.g.close();
if(this.runStats != null) {
this.runStats.close();
}
this.runStats = null;
}
protected void finalize() throws Throwable {
try {
this.close();
} finally {
super.finalize();
}
}
public void feed(String var1, float[] var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, FloatBuffer.wrap(var2)));
}
public void feed(String var1, int[] var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, IntBuffer.wrap(var2)));
}
public void feed(String var1, long[] var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, LongBuffer.wrap(var2)));
}
public void feed(String var1, double[] var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, DoubleBuffer.wrap(var2)));
}
public void feed(String var1, byte[] var2, long... var3) {
this.addFeed(var1, Tensor.create(UInt8.class, var3, ByteBuffer.wrap(var2)));
}
public void feedString(String var1, byte[] var2) {
this.addFeed(var1, Tensors.create(var2));
}
public void feedString(String var1, byte[][] var2) {
this.addFeed(var1, Tensors.create(var2));
}
public void feed(String var1, FloatBuffer var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, var2));
}
public void feed(String var1, IntBuffer var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, var2));
}
public void feed(String var1, LongBuffer var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, var2));
}
public void feed(String var1, DoubleBuffer var2, long... var3) {
this.addFeed(var1, Tensor.create(var3, var2));
}
public void feed(String var1, ByteBuffer var2, long... var3) {
this.addFeed(var1, Tensor.create(UInt8.class, var3, var2));
}
public void fetch(String var1, float[] var2) {
this.fetch(var1, FloatBuffer.wrap(var2));
}
public void fetch(String var1, int[] var2) {
this.fetch(var1, IntBuffer.wrap(var2));
}
public void fetch(String var1, long[] var2) {
this.fetch(var1, LongBuffer.wrap(var2));
}
public void fetch(String var1, double[] var2) {
this.fetch(var1, DoubleBuffer.wrap(var2));
}
public void fetch(String var1, byte[] var2) {
this.fetch(var1, ByteBuffer.wrap(var2));
}
public void fetch(String var1, FloatBuffer var2) {
this.getTensor(var1).writeTo(var2);
}
public void fetch(String var1, IntBuffer var2) {
this.getTensor(var1).writeTo(var2);
}
public void fetch(String var1, LongBuffer var2) {
this.getTensor(var1).writeTo(var2);
}
public void fetch(String var1, DoubleBuffer var2) {
this.getTensor(var1).writeTo(var2);
}
public void fetch(String var1, ByteBuffer var2) {
this.getTensor(var1).writeTo(var2);
}
private void prepareNativeRuntime() {
System.out.println("Checking to see if TensorFlow native methods are already loaded");
try {
new RunStats();
System.out.println("TensorFlow native methods already loaded");
} catch (UnsatisfiedLinkError var4) {
System.out.println("TensorFlow native methods not found, attempting to load via tensorflow_inference");
/* try {
System.loadLibrary("tensorflow_inference");
System.out.println("Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
} catch (UnsatisfiedLinkError var3) {
throw new RuntimeException("Native TF methods not found; check that the correct native libraries are present in the APK.");
}*/
}
}
private void loadGraph(byte[] var1, Graph var2) throws IOException {
long var3 = System.currentTimeMillis();
try {
var2.importGraphDef(var1);
} catch (IllegalArgumentException var7) {
throw new IOException("Not a valid TensorFlow Graph serialization: " + var7.getMessage());
}
long var5 = System.currentTimeMillis();
System.out.println("Model load took " + (var5 - var3) + "ms, TensorFlow version: " + TensorFlow.version());
}
private void addFeed(String var1, Tensor<?> var2) {
TensorFlowInferenceInterface.TensorId var3 = TensorFlowInferenceInterface.TensorId.parse(var1);
this.runner.feed(var3.name, var3.outputIndex, var2);
this.feedNames.add(var1);
this.feedTensors.add(var2);
}
private Tensor<?> getTensor(String var1) {
int var2 = 0;
for(Iterator var3 = this.fetchNames.iterator(); var3.hasNext(); ++var2) {
String var4 = (String)var3.next();
if(var4.equals(var1)) {
return (Tensor)this.fetchTensors.get(var2);
}
}
throw new RuntimeException("Node '" + var1 + "' was not provided to run(), so it cannot be read");
}
private void closeFeeds() {
Iterator var1 = this.feedTensors.iterator();
while(var1.hasNext()) {
Tensor var2 = (Tensor)var1.next();
var2.close();
}
this.feedTensors.clear();
this.feedNames.clear();
}
private void closeFetches() {
Iterator var1 = this.fetchTensors.iterator();
while(var1.hasNext()) {
Tensor var2 = (Tensor)var1.next();
var2.close();
}
this.fetchTensors.clear();
this.fetchNames.clear();
}
private static class TensorId {
String name;
int outputIndex;
private TensorId() {
}
public static TensorFlowInferenceInterface.TensorId parse(String var0) {
TensorFlowInferenceInterface.TensorId var1 = new TensorFlowInferenceInterface.TensorId();
int var2 = var0.lastIndexOf(58);
if(var2 < 0) {
var1.outputIndex = 0;
var1.name = var0;
return var1;
} else {
try {
var1.outputIndex = Integer.parseInt(var0.substring(var2 + 1));
var1.name = var0.substring(0, var2);
} catch (NumberFormatException var4) {
var1.outputIndex = 0;
var1.name = var0;
}
return var1;
}
}
}
}