Is there a way to get the color of a recognized object inside an image?

Time: 2017-12-26 11:00:21

Tags: java tensorflow image-recognition

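I am using Tensorflow to recognize objects in a provided picture. Following this tutorial and using this repo, I managed to get my program to return the objects in the picture. For example, this is the picture I used as input: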

red-tshirt.jpg

This is the output of my program:

[image: program output]

All I want is to get the color of the recognized item (red for the jersey in this case). Is that possible?

Here is the code (from the last link, with only minor changes):

/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.test.sec.compoment;

import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;

/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
public class ImageRecognition {
  private static void printUsage(PrintStream s) {
    final String url =
        "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
    s.println(
        "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
    s.println("to label JPEG images.");
    s.println("TensorFlow version: " + TensorFlow.version());
    s.println();
    s.println("Usage: label_image <model dir> <image file>");
    s.println();
    s.println("Where:");
    s.println("<model dir> is a directory containing the unzipped contents of the inception model");
    s.println("            (from " + url + ")");
    s.println("<image file> is the path to a JPEG image file");
  }

  public void index() {
    String modelDir = "C:/Users/Admin/Downloads/inception5h";
    String imageFile = "C:/Users/Admin/Desktop/red-tshirt.jpg";

    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
    List<String> labels =
        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

    try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(labelProbabilities);
      System.out.println(
          String.format("BEST MATCH: %s (%.2f%% likely)",
              labels.get(bestLabelIdx),
              labelProbabilities[bestLabelIdx] * 100f));
    }
  }

  private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // Some constants specific to the pre-trained model at:
      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
      //
      // - The model was trained with images scaled to 224x224 pixels.
      // - The colors, represented as R, G, B in 1-byte each were converted to
      //   float using (value - Mean)/Scale.
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;

      // Since the graph is being constructed once per execution here, we can use a constant for the
      // input image. If the graph were to be re-used for multiple input images, a placeholder would
      // have been more appropriate.
      final Output<String> input = b.constant("input", imageBytes);
      final Output<Float> output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), Float.class),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
    }
  }

  private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          Tensor<Float> result =
              s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  private static int maxIndex(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null;
  }

  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null;
  }

  // In the fullness of time, equivalents of the methods of this class should be auto-generated from
  // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
  // like Python, C++ and Go.
  static class GraphBuilder {
    GraphBuilder(Graph g) {
      this.g = g;
    }

    Output<Float> div(Output<Float> x, Output<Float> y) {
      return binaryOp("Div", x, y);
    }

    <T> Output<T> sub(Output<T> x, Output<T> y) {
      return binaryOp("Sub", x, y);
    }

    <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
      return binaryOp3("ResizeBilinear", images, size);
    }

    <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
      return binaryOp3("ExpandDims", input, dim);
    }

    <T, U> Output<U> cast(Output<T> value, Class<U> type) {
      DataType dtype = DataType.fromClass(type);
      return g.opBuilder("Cast", "Cast")
          .addInput(value)
          .setAttr("DstT", dtype)
          .build()
          .<U>output(0);
    }

    Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .<UInt8>output(0);
    }

    <T> Output<T> constant(String name, Object value, Class<T> type) {
      try (Tensor<T> t = Tensor.<T>create(value, type)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", DataType.fromClass(type))
            .setAttr("value", t)
            .build()
            .<T>output(0);
      }
    }
    Output<String> constant(String name, byte[] value) {
      return this.constant(name, value, String.class);
    }

    Output<Integer> constant(String name, int value) {
      return this.constant(name, value, Integer.class);
    }

    Output<Integer> constant(String name, int[] value) {
      return this.constant(name, value, Integer.class);
    }

    Output<Float> constant(String name, float value) {
      return this.constant(name, value, Float.class);
    }

    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }

    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }
    private Graph g;
  }
}

3 Answers:

Answer 0 (score: 6)

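You are using code that predicts a label for the given image, i.e. it classifies the whole image into one of the trained classes, so you don't know which exact pixels belong to the object.

So I suggest you do one of the following:

  1. Use an object detector to locate the object and get its bounding box. Then take the color that covers the most pixels inside the box.
  2. Use pixel-wise classification (segmentation) like this to get the exact pixels of the object.
  3. Note that you may need to train the network (or model) for your objects yourself.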

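Edit:

For a Java object detection example, check out this project. It is written for Android, but using it in a desktop application should be straightforward. More specifically, look at this part.

You don't need both object detection and segmentation, but if you want both, I think you should first try training a segmentation model in Python (link provided above) and then use that model in Java in the same way as the object detection models.

Edit 2:

I have added a simple object detection client in Java that uses the Tensorflow Object Detection API models, just to show that you can use any frozen model in Java.

Also, check out this nice repository, which uses pixel-wise segmentation.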
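Option 1 in the list above (take the most frequent color inside the detector's bounding box) could be sketched roughly as follows. This is only an illustration under stated assumptions: the bounding box is assumed to come from whatever detector you use, `DominantColor` and `dominantRgb` are hypothetical names that are not part of the TensorFlow API, and the 8-levels-per-channel quantization is an arbitrary choice to keep near-identical shades together.

```java
import java.awt.image.BufferedImage;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper; not part of TensorFlow or the question's code. */
final class DominantColor {

  /**
   * Returns the most frequent coarsely quantized color inside the box
   * [x, x + w) x [y, y + h), packed as 0xRRGGBB, or -1 if the box is empty.
   */
  static int dominantRgb(BufferedImage img, int x, int y, int w, int h) {
    // Clamp the box to the image so a slightly-off detection cannot throw.
    int x0 = Math.max(0, x);
    int y0 = Math.max(0, y);
    int x1 = Math.min(img.getWidth(), x + w);
    int y1 = Math.min(img.getHeight(), y + h);

    Map<Integer, Integer> counts = new HashMap<>();
    int best = -1;
    int bestCount = 0;
    for (int py = y0; py < y1; py++) {
      for (int px = x0; px < x1; px++) {
        // Keep the top 3 bits of each channel (drops alpha and fine detail).
        int bucket = img.getRGB(px, py) & 0xE0E0E0;
        int c = counts.merge(bucket, 1, Integer::sum);
        if (c > bestCount) {
          bestCount = c;
          best = bucket;
        }
      }
    }
    // Report the middle of the winning bucket rather than its lower edge.
    return bestCount == 0 ? -1 : best | 0x101010;
  }

  public static void main(String[] args) {
    // Synthetic stand-in for a detected region: a 4x4 pure red image.
    BufferedImage img = new BufferedImage(4, 4, BufferedImage.TYPE_INT_RGB);
    for (int y = 0; y < 4; y++) {
      for (int x = 0; x < 4; x++) {
        img.setRGB(x, y, 0xFF0000);
      }
    }
    System.out.printf("#%06X%n", dominantRgb(img, 0, 0, 4, 4)); // prints #F01010
  }
}
```

A bounding box usually contains some background as well, so the most frequent bucket is only a reasonable guess when the object fills most of the box.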


Answer 1 (score: 0)

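First remove the background pixels so that only the object remains, then build a list of all remaining pixels, and finally compute their average color.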
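The averaging step could look something like the sketch below. It assumes the background removal has already produced a boolean mask (`mask[y][x]` is true for object pixels); how that mask is obtained is outside this sketch, and `MeanObjectColor` and `meanRgb` are hypothetical names used only for illustration.

```java
import java.awt.image.BufferedImage;

/** Hypothetical helper; the mask is assumed to come from a separate background-removal step. */
final class MeanObjectColor {

  /**
   * Averages the R, G and B channels over all pixels where mask[y][x] is true.
   * Returns the result packed as 0xRRGGBB, or -1 if the mask selects no pixel.
   */
  static int meanRgb(BufferedImage img, boolean[][] mask) {
    long r = 0, g = 0, b = 0, n = 0;
    for (int y = 0; y < img.getHeight(); y++) {
      for (int x = 0; x < img.getWidth(); x++) {
        if (!mask[y][x]) {
          continue; // background pixel, ignore it
        }
        int rgb = img.getRGB(x, y);
        r += (rgb >> 16) & 0xFF;
        g += (rgb >> 8) & 0xFF;
        b += rgb & 0xFF;
        n++;
      }
    }
    if (n == 0) {
      return -1;
    }
    return (int) (((r / n) << 16) | ((g / n) << 8) | (b / n));
  }

  public static void main(String[] args) {
    // Two-pixel synthetic image: one red pixel, one blue pixel, both "object".
    BufferedImage img = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
    img.setRGB(0, 0, 0xFF0000);
    img.setRGB(1, 0, 0x0000FF);
    boolean[][] mask = {{true, true}};
    System.out.printf("#%06X%n", meanRgb(img, mask)); // prints #7F007F
  }
}
```

As the example output shows, a plain mean of a red and a blue pixel is a purple that appears in neither, so for multi-colored objects the average may not match any color a person would name.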

For color detection methods, you can look at Color Image Processing: Emerging Applications, Color Detection, and most importantly How we handle color detection.

Regards

Answer 2 (score: 0)

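Use the code snippet below to get the RGB color code. Because an image may contain pixels of different colors, you can pick a point (e.g. the center) and get the RGB code at its vertical (Y) and horizontal (X) coordinates.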
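A minimal sketch of reading one pixel this way, assuming the picture is available as a `java.awt.image.BufferedImage`. `PixelColor`, `rgbAt` and `centerRgb` are illustrative names and not the answerer's code.

```java
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import javax.imageio.ImageIO;

/** Illustrative helper for sampling a single pixel. */
final class PixelColor {

  /** Returns {red, green, blue} (each 0..255) of the pixel at column x, row y. */
  static int[] rgbAt(BufferedImage img, int x, int y) {
    int argb = img.getRGB(x, y); // packed as 0xAARRGGBB
    return new int[] {(argb >> 16) & 0xFF, (argb >> 8) & 0xFF, argb & 0xFF};
  }

  /** Samples the pixel at the center of the image. */
  static int[] centerRgb(BufferedImage img) {
    return rgbAt(img, img.getWidth() / 2, img.getHeight() / 2);
  }

  public static void main(String[] args) throws IOException {
    BufferedImage img;
    if (args.length > 0) {
      // Load the file given on the command line, e.g. red-tshirt.jpg.
      img = ImageIO.read(new File(args[0]));
      if (img == null) {
        System.err.println("Unsupported image format: " + args[0]);
        return;
      }
    } else {
      // No file given: fall back to a 3x3 synthetic image with a colored center.
      img = new BufferedImage(3, 3, BufferedImage.TYPE_INT_RGB);
      img.setRGB(1, 1, 0xCC2233);
    }
    System.out.println(Arrays.toString(centerRgb(img)));
  }
}
```

Run without arguments, this prints `[204, 34, 51]`. A single sample is sensitive to logos, shadows and highlights, so the result depends heavily on which point is chosen.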
