Java中的Yolo探测器

时间:2017-11-02 14:39:24

标签: tensorflow

我正在尝试用Java实现Yolo检测器(不是Android,而是桌面 - Windows / Ubuntu)

Android已有Yolo探测器:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android

我从该项目复制了一些java类,将它们添加到IntelliJ IDEA并编辑它们

我甚至复制并编辑了TensorFlowInferenceInterface.java来自jar文件(tensorflow库 - libandroid_tensorflow_inference_java.jar)for Android

我几乎成功地让它发挥作用。

结果

enter image description here

控制台输出(类标题,置信度,x,y,宽度,高度):

  

汽车0.8836523,148 166 270 267

     

car 0.51286024,147 174 268 274

     

car 0.05002968,174 164 275 262

所以似乎正确检测到汽车,确定正确 x y 宽度高度

有问题

可能出现什么问题?

这是我项目的完整代码

主要

public class Main implements Classifier {

    private static final int BLOCK_SIZE = 32;
    private static final int MAX_RESULTS = 3;
    private static final int NUM_CLASSES = 20;
    private static final int NUM_BOXES_PER_BLOCK = 5;
    private static final int INPUT_SIZE = 416;
    private static final String inputName = "input";
    private static final String outputName = "output";

    // Pre-allocated buffers.
    private static int[] intValues;
    private static float[] floatValues;
    private static String[] outputNames;

    // yolo 2
    private static final double[] ANCHORS = { 1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071 };

    // tiny yolo
    //private static final double[] ANCHORS = { 1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52 };

    private static final String[] LABELS = {
            "aeroplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "diningtable",
            "dog",
            "horse",
            "motorbike",
            "person",
            "pottedplant",
            "sheep",
            "sofa",
            "train",
            "tvmonitor"
    };

    private static TensorFlowInferenceInterface inferenceInterface;

    public static void main(String[] args) {

        //String modelDir = "/home/user/JavaProjects/TensorFlowJavaProject"; // Ubuntu
        String modelAndTestImagesDir = "D:\\JavaProjects\\TensorFlowJavaProject"; // Windows
        String imageFile = modelAndTestImagesDir + File.separator + "0.png"; // 416x416 test image

        outputNames = outputName.split(",");
        floatValues = new float[INPUT_SIZE * INPUT_SIZE * 3];

        // yolo 2 voc
        inferenceInterface = new TensorFlowInferenceInterface(Paths.get(modelAndTestImagesDir, "yolo-voc.pb"));

        // tiny yolo voc
        //inferenceInterface = new TensorFlowInferenceInterface(Paths.get(modelAndTestImagesDir, "graph-tiny-yolo-voc.pb"));

        BufferedImage img;

        try {
            img = ImageIO.read(new File(imageFile));

            BufferedImage convertedImg = new BufferedImage(img.getWidth(), img.getHeight(), BufferedImage.TYPE_INT_RGB);
            convertedImg.getGraphics().drawImage(img, 0, 0, null);

            intValues = ((DataBufferInt) convertedImg.getRaster().getDataBuffer()).getData() ;

            List<Classifier.Recognition> recognitions = recognizeImage();

            System.out.println("Result length " + recognitions.size());

            Graphics2D graphics = convertedImg.createGraphics();

            for (Recognition recognition : recognitions) {
                RectF rectF = recognition.getLocation();
                System.out.println(recognition.getTitle() + " " + recognition.getConfidence() + ", " +
                        (int) rectF.x + " " + (int) rectF.y + " " + (int) rectF.width + " " + ((int) rectF.height));
                Stroke stroke = graphics.getStroke();
                graphics.setStroke(new BasicStroke(3));
                graphics.setColor(Color.green);
                graphics.drawRoundRect((int) rectF.x, (int) rectF.y, (int) rectF.width, (int) rectF.height, 5, 5);
                graphics.setStroke(stroke);
            }

            graphics.dispose();
            ImageIcon icon=new ImageIcon(convertedImg);
            JFrame frame=new JFrame();
            frame.setLayout(new FlowLayout());
            frame.setSize(convertedImg.getWidth(),convertedImg.getHeight());
            JLabel lbl=new JLabel();
            frame.setTitle("Java (Win/Ubuntu), Tensorflow & Yolo");
            lbl.setIcon(icon);
            frame.add(lbl);
            frame.setVisible(true);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

        } catch (IOException e) {
            e.printStackTrace();
        }


    }

    private static List<Classifier.Recognition> recognizeImage() {

        for (int i = 0; i < intValues.length; ++i) {
            floatValues[i * 3 + 0] = ((intValues[i] >> 16) & 0xFF) / 255.0f;
            floatValues[i * 3 + 1] = ((intValues[i] >> 8) & 0xFF) / 255.0f;
            floatValues[i * 3 + 2] = (intValues[i] & 0xFF) / 255.0f;
        }
        inferenceInterface.feed(inputName, floatValues, 1, INPUT_SIZE, INPUT_SIZE, 3);

        inferenceInterface.run(outputNames, false);

        final int gridWidth = INPUT_SIZE / BLOCK_SIZE;
        final int gridHeight = INPUT_SIZE / BLOCK_SIZE;

        final float[] output = new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];

        inferenceInterface.fetch(outputNames[0], output);

        // Find the best detections.
        final PriorityQueue<Classifier.Recognition> pq =
                new PriorityQueue<>(
                        1,
                        new Comparator<Classifier.Recognition>() {
                            @Override
                            public int compare(final Classifier.Recognition lhs, final Classifier.Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });

        for (int y = 0; y < gridHeight; ++y) {
            for (int x = 0; x < gridWidth; ++x) {
                for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
                    final int offset =
                            (gridWidth * (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5))) * y
                                    + (NUM_BOXES_PER_BLOCK * (NUM_CLASSES + 5)) * x
                                    + (NUM_CLASSES + 5) * b;

                    final float xPos = (x + expit(output[offset + 0])) * BLOCK_SIZE;
                    final float yPos = (y + expit(output[offset + 1])) * BLOCK_SIZE;

                    final float w = (float) (Math.exp(output[offset + 2]) * ANCHORS[2 * b + 0]) * BLOCK_SIZE;
                    final float h = (float) (Math.exp(output[offset + 3]) * ANCHORS[2 * b + 1]) * BLOCK_SIZE;

                    final RectF rect =
                            new RectF(
                                    Math.max(0, xPos - w / 2),
                                    Math.max(0, yPos - h / 2),
                                    Math.min(INPUT_SIZE - 1, xPos + w / 2),
                                    Math.min(INPUT_SIZE - 1, yPos + h / 2));

                    final float confidence = expit(output[offset + 4]);

                    int detectedClass = -1;
                    float maxClass = 0;

                    final float[] classes = new float[NUM_CLASSES];
                    for (int c = 0; c < NUM_CLASSES; ++c) {
                        classes[c] = output[offset + 5 + c];
                    }
                    softmax(classes);

                    for (int c = 0; c < NUM_CLASSES; ++c) {
                        if (classes[c] > maxClass) {
                            detectedClass = c;
                            maxClass = classes[c];
                        }
                    }

                    final float confidenceInClass = maxClass * confidence;
                    if (confidenceInClass > 0.01) {
                        pq.add(new Classifier.Recognition(detectedClass, LABELS[detectedClass], confidenceInClass, rect));
                    }
                }
            }
        }

        final ArrayList<Classifier.Recognition> recognitions = new ArrayList<>();
        for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
            recognitions.add(pq.poll());
        }
        return recognitions;

    }

    private static float expit(final float x) {
        return (float) (1. / (1. + Math.exp(-x)));
    }

    private static void softmax(final float[] vals) {
        float max = Float.NEGATIVE_INFINITY;
        for (final float val : vals) {
            max = Math.max(max, val);
        }
        float sum = 0.0f;
        for (int i = 0; i < vals.length; ++i) {
            vals[i] = (float) Math.exp(vals[i] - max);
            sum += vals[i];
        }
        for (int i = 0; i < vals.length; ++i) {
            vals[i] = vals[i] / sum;
        }
    }

    public void close() {
        inferenceInterface.close();
    }
}

TensorFlowInferenceInterface

public class TensorFlowInferenceInterface {
    private static final String TAG = "TensorFlowInferenceInterface";
    private final Graph g;
    private final Session sess;
    private Runner runner;
    private List<String> feedNames = new ArrayList();
    private List<Tensor> feedTensors = new ArrayList();
    private List<String> fetchNames = new ArrayList();
    private List<Tensor> fetchTensors = new ArrayList();
    private RunStats runStats;

    public TensorFlowInferenceInterface(Path path) {
        this.prepareNativeRuntime();
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();

        try {
            this.loadGraph(readAllBytesOrExit(path), this.g);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }

    public void run(String[] var1) {
        this.run(var1, false);
    }

    public void run(String[] var1, boolean var2) {
        this.closeFetches();
        String[] var3 = var1;
        int var4 = var1.length;

        for (int var5 = 0; var5 < var4; ++var5) {
            String var6 = var3[var5];
            this.fetchNames.add(var6);
            TensorFlowInferenceInterface.TensorId var7 = TensorFlowInferenceInterface.TensorId.parse(var6);
            this.runner.fetch(var7.name, var7.outputIndex);
        }

        try {
            if (var2) {
                Run var13 = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                this.fetchTensors = var13.outputs;
                if (this.runStats == null) {
                    this.runStats = new RunStats();
                }

                this.runStats.add(var13.metadata);
            } else {
                this.fetchTensors = this.runner.run();
            }
        } catch (RuntimeException var11) {
            throw var11;
        } finally {
            this.closeFeeds();
            this.runner = this.sess.runner();
        }

    }

    public Graph graph() {
        return this.g;
    }

    public Operation graphOperation(String var1) {
        Operation var2 = this.g.operation(var1);
        if (var2 == null) {
            throw new RuntimeException("Node '" + var1 + "' does not exist in model '");
        } else {
            return var2;
        }
    }

    public String getStatString() {
        return this.runStats == null ? "" : this.runStats.summary();
    }

    public void close() {
        this.closeFeeds();
        this.closeFetches();
        this.sess.close();
        this.g.close();
        if (this.runStats != null) {
            this.runStats.close();
        }

        this.runStats = null;
    }

    protected void finalize() throws Throwable {
        try {
            this.close();
        } finally {
            super.finalize();
        }

    }

    public void feed(String var1, float[] var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, FloatBuffer.wrap(var2)));
    }

    public void fetch(String var1, float[] var2) {
        this.fetch(var1, FloatBuffer.wrap(var2));
    }

    public void fetch(String var1, FloatBuffer var2) {
        this.getTensor(var1).writeTo(var2);
    }

    private void prepareNativeRuntime() {
        System.out.println("TensorFlowInferenceInterface Checking to see if TensorFlow native methods are already loaded");

        try {
            new RunStats();
            System.out.println("TensorFlowInferenceInterface TensorFlow native methods already loaded");
        } catch (UnsatisfiedLinkError var4) {
            System.out.println("TensorFlowInferenceInterface TensorFlow native methods not found, attempting to load via tensorflow_inference");
        }

    }

    private void loadGraph(byte[] var1, Graph var2) throws IOException {
        try {
            var2.importGraphDef(var1);
        } catch (IllegalArgumentException var7) {
            throw new IOException("Not a valid TensorFlow Graph serialization: " + var7.getMessage());
        }
    }

    private void addFeed(String var1, Tensor var2) {
        TensorFlowInferenceInterface.TensorId var3 = TensorFlowInferenceInterface.TensorId.parse(var1);
        this.runner.feed(var3.name, var3.outputIndex, var2);
        this.feedNames.add(var1);
        this.feedTensors.add(var2);
    }

    private Tensor getTensor(String var1) {
        int var2 = 0;

        for (Iterator var3 = this.fetchNames.iterator(); var3.hasNext(); ++var2) {
            String var4 = (String) var3.next();
            if (var4.equals(var1)) {
                return this.fetchTensors.get(var2);
            }
        }

        throw new RuntimeException("Node '" + var1 + "' was not provided to run(), so it cannot be read");
    }

    private void closeFeeds() {
        Iterator var1 = this.feedTensors.iterator();

        while (var1.hasNext()) {
            Tensor var2 = (Tensor) var1.next();
            var2.close();
        }

        this.feedTensors.clear();
        this.feedNames.clear();
    }

    private void closeFetches() {
        Iterator var1 = this.fetchTensors.iterator();

        while (var1.hasNext()) {
            Tensor var2 = (Tensor) var1.next();
            var2.close();
        }

        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    private static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        public static TensorFlowInferenceInterface.TensorId parse(String var0) {
            TensorFlowInferenceInterface.TensorId var1 = new TensorFlowInferenceInterface.TensorId();
            int var2 = var0.lastIndexOf(58);
            if (var2 < 0) {
                var1.outputIndex = 0;
                var1.name = var0;
                return var1;
            } else {
                try {
                    var1.outputIndex = Integer.parseInt(var0.substring(var2 + 1));
                    var1.name = var0.substring(0, var2);
                } catch (NumberFormatException var4) {
                    var1.outputIndex = 0;
                    var1.name = var0;
                }

                return var1;
            }
        }
    }
}

分类

public interface Classifier {

  public class Recognition {

    private final int id;

    private final String title;

    private final Float confidence;

    private RectF location;

    public Recognition(
            final int id, final String title, final Float confidence, final RectF location) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.location = location;
    }

    public int getId() {
      return id;
    }

    public String getTitle() {
      return title;
    }

    public Float getConfidence() {
      return confidence;
    }

    public RectF getLocation() {
      return new RectF(location);
    }

    public void setLocation(RectF location) {
      this.location = location;
    }
  }

  void close();
}

RUNSTATS

public class RunStats implements AutoCloseable {
    private long nativeHandle = allocate();
    private static byte[] fullTraceRunOptions = new byte[]{8, 3};

    public static byte[] runOptions() {
        return fullTraceRunOptions;
    }

    public RunStats() {
    }

    public void close() {
        if(this.nativeHandle != 0L) {
            delete(this.nativeHandle);
        }

        this.nativeHandle = 0L;
    }

    public synchronized void add(byte[] var1) {
        add(this.nativeHandle, var1);
    }

    public synchronized String summary() {
        return summary(this.nativeHandle);
    }

    private static native long allocate();

    private static native void delete(long var0);

    private static native void add(long var0, byte[] var2);

    private static native String summary(long var0);
}

RectF

public class RectF {

    public float getX() {
        return x;
    }

    public void setX(float x) {
        this.x = x;
    }

    public float getY() {
        return y;
    }

    public void setY(float y) {
        this.y = y;
    }

    public float getWidth() {
        return width;
    }

    public void setWidth(float width) {
        this.width = width;
    }

    public float getHeight() {
        return height;
    }

    public void setHeight(float height) {
        this.height = height;
    }

    public float x = 0f;
    public float y = 0f;
    public float width = 0f;
    public float height = 0f;


    RectF(RectF rectF) {
        this.x = rectF.x;
        this.y = rectF.y;
        this.width = rectF.width;
        this.height = rectF.height;
    }

    RectF(float x, float y, float w, float h) {
        this.x = x;
        this.y = y;
        this.width = w;
        this.height = h;
    }
}

1 个答案:

答案 0 :(得分:3)

已解决(我混淆了x,y,width,heightleft,top,right,bottom

这是更新的RectF

public class RectF {
    public float left;
    public float top;
    public float right;
    public float bottom;

    public RectF() {}

    public RectF(float left, float top, float right, float bottom) {
        this.left = left;
        this.top = top;
        this.right = right;
        this.bottom = bottom;
    }

    public RectF(RectF r) {
        if (r == null) {
            left = top = right = bottom = 0.0f;
        } else {
            left = r.left;
            top = r.top;
            right = r.right;
            bottom = r.bottom;
        }
    }

    public String toString() {
        return "RectF(" + left + ", " + top + ", "
                + right + ", " + bottom + ")";
    }

    public final float width() {
        return right - left;
    }

    public final float height() {
        return bottom - top;
    }

    public final float centerX() {
        return (left + right) * 0.5f;
    }

    public final float centerY() {
        return (top + bottom) * 0.5f;
    }
}

然后

graphics.drawRoundRect((int) rectF.left, (int) rectF.top, (int) rectF.width(), (int) rectF.height(), 5, 5);

enter image description here

P.S。对于TensorFlow 1.4.0,这里是一个更新的TensorFlowInferenceInterface类:

public class TensorFlowInferenceInterface {
    private final Graph g;
    private final Session sess;
    private Runner runner;
    private List<String> feedNames = new ArrayList();
    private List<Tensor<?>> feedTensors = new ArrayList();
    private List<String> fetchNames = new ArrayList();
    private List<Tensor<?>> fetchTensors = new ArrayList();
    private RunStats runStats;

    public TensorFlowInferenceInterface(Path path) {
        this.prepareNativeRuntime();
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();

        try {
            this.loadGraph(readAllBytesOrExit(path), this.g);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }

    public void run(String[] var1) {
        this.run(var1, false);
    }

    public void run(String[] var1, boolean var2) {
        this.closeFetches();
        String[] var3 = var1;
        int var4 = var1.length;

        for(int var5 = 0; var5 < var4; ++var5) {
            String var6 = var3[var5];
            this.fetchNames.add(var6);
            TensorFlowInferenceInterface.TensorId var7 = TensorFlowInferenceInterface.TensorId.parse(var6);
            this.runner.fetch(var7.name, var7.outputIndex);
        }

        try {
            if(var2) {
                Run var13 = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                this.fetchTensors = var13.outputs;
                if(this.runStats == null) {
                    this.runStats = new RunStats();
                }

                this.runStats.add(var13.metadata);
            } else {
                this.fetchTensors = this.runner.run();
            }
        } catch (RuntimeException var11) {
            System.out.println("Failed to run TensorFlow inference with inputs:["
                    + String.join(", ", this.feedNames)
                    + "], outputs:[" + String.join(", ", this.fetchNames) + "]");
            throw var11;
        } finally {
            this.closeFeeds();
            this.runner = this.sess.runner();
        }

    }

    public Graph graph() {
        return this.g;
    }

    public Operation graphOperation(String var1) {
        Operation var2 = this.g.operation(var1);
        if(var2 == null) {
            throw new RuntimeException("Node '" + var1 + "' does not exist in model '");
        } else {
            return var2;
        }
    }

    public String getStatString() {
        return this.runStats == null?"":this.runStats.summary();
    }

    public void close() {
        this.closeFeeds();
        this.closeFetches();
        this.sess.close();
        this.g.close();
        if(this.runStats != null) {
            this.runStats.close();
        }

        this.runStats = null;
    }

    protected void finalize() throws Throwable {
        try {
            this.close();
        } finally {
            super.finalize();
        }

    }

    public void feed(String var1, float[] var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, FloatBuffer.wrap(var2)));
    }

    public void feed(String var1, int[] var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, IntBuffer.wrap(var2)));
    }

    public void feed(String var1, long[] var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, LongBuffer.wrap(var2)));
    }

    public void feed(String var1, double[] var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, DoubleBuffer.wrap(var2)));
    }

    public void feed(String var1, byte[] var2, long... var3) {
        this.addFeed(var1, Tensor.create(UInt8.class, var3, ByteBuffer.wrap(var2)));
    }

    public void feedString(String var1, byte[] var2) {
        this.addFeed(var1, Tensors.create(var2));
    }

    public void feedString(String var1, byte[][] var2) {
        this.addFeed(var1, Tensors.create(var2));
    }

    public void feed(String var1, FloatBuffer var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, var2));
    }

    public void feed(String var1, IntBuffer var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, var2));
    }

    public void feed(String var1, LongBuffer var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, var2));
    }

    public void feed(String var1, DoubleBuffer var2, long... var3) {
        this.addFeed(var1, Tensor.create(var3, var2));
    }

    public void feed(String var1, ByteBuffer var2, long... var3) {
        this.addFeed(var1, Tensor.create(UInt8.class, var3, var2));
    }

    public void fetch(String var1, float[] var2) {
        this.fetch(var1, FloatBuffer.wrap(var2));
    }

    public void fetch(String var1, int[] var2) {
        this.fetch(var1, IntBuffer.wrap(var2));
    }

    public void fetch(String var1, long[] var2) {
        this.fetch(var1, LongBuffer.wrap(var2));
    }

    public void fetch(String var1, double[] var2) {
        this.fetch(var1, DoubleBuffer.wrap(var2));
    }

    public void fetch(String var1, byte[] var2) {
        this.fetch(var1, ByteBuffer.wrap(var2));
    }

    public void fetch(String var1, FloatBuffer var2) {
        this.getTensor(var1).writeTo(var2);
    }

    public void fetch(String var1, IntBuffer var2) {
        this.getTensor(var1).writeTo(var2);
    }

    public void fetch(String var1, LongBuffer var2) {
        this.getTensor(var1).writeTo(var2);
    }

    public void fetch(String var1, DoubleBuffer var2) {
        this.getTensor(var1).writeTo(var2);
    }

    public void fetch(String var1, ByteBuffer var2) {
        this.getTensor(var1).writeTo(var2);
    }

    private void prepareNativeRuntime() {
        System.out.println("Checking to see if TensorFlow native methods are already loaded");

        try {
            new RunStats();
            System.out.println("TensorFlow native methods already loaded");
        } catch (UnsatisfiedLinkError var4) {
            System.out.println("TensorFlow native methods not found, attempting to load via tensorflow_inference");

         /*   try {
                System.loadLibrary("tensorflow_inference");
                System.out.println("Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
            } catch (UnsatisfiedLinkError var3) {
                throw new RuntimeException("Native TF methods not found; check that the correct native libraries are present in the APK.");
            }*/
        }

    }

    private void loadGraph(byte[] var1, Graph var2) throws IOException {
        long var3 = System.currentTimeMillis();
        try {
            var2.importGraphDef(var1);
        } catch (IllegalArgumentException var7) {
            throw new IOException("Not a valid TensorFlow Graph serialization: " + var7.getMessage());
        }


        long var5 = System.currentTimeMillis();
        System.out.println("Model load took " + (var5 - var3) + "ms, TensorFlow version: " + TensorFlow.version());
    }

    private void addFeed(String var1, Tensor<?> var2) {
        TensorFlowInferenceInterface.TensorId var3 = TensorFlowInferenceInterface.TensorId.parse(var1);
        this.runner.feed(var3.name, var3.outputIndex, var2);
        this.feedNames.add(var1);
        this.feedTensors.add(var2);
    }

    private Tensor<?> getTensor(String var1) {
        int var2 = 0;

        for(Iterator var3 = this.fetchNames.iterator(); var3.hasNext(); ++var2) {
            String var4 = (String)var3.next();
            if(var4.equals(var1)) {
                return (Tensor)this.fetchTensors.get(var2);
            }
        }

        throw new RuntimeException("Node '" + var1 + "' was not provided to run(), so it cannot be read");
    }

    private void closeFeeds() {
        Iterator var1 = this.feedTensors.iterator();

        while(var1.hasNext()) {
            Tensor var2 = (Tensor)var1.next();
            var2.close();
        }

        this.feedTensors.clear();
        this.feedNames.clear();
    }

    private void closeFetches() {
        Iterator var1 = this.fetchTensors.iterator();

        while(var1.hasNext()) {
            Tensor var2 = (Tensor)var1.next();
            var2.close();
        }

        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    private static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        public static TensorFlowInferenceInterface.TensorId parse(String var0) {
            TensorFlowInferenceInterface.TensorId var1 = new TensorFlowInferenceInterface.TensorId();
            int var2 = var0.lastIndexOf(58);
            if(var2 < 0) {
                var1.outputIndex = 0;
                var1.name = var0;
                return var1;
            } else {
                try {
                    var1.outputIndex = Integer.parseInt(var0.substring(var2 + 1));
                    var1.name = var0.substring(0, var2);
                } catch (NumberFormatException var4) {
                    var1.outputIndex = 0;
                    var1.name = var0;
                }

                return var1;
            }
        }
    }
}