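TensorFlow object recognition: problem with a trained model

Time: 2018-09-23 13:09:02

Tags: java tensorflow

I am writing a program that recognizes objects, but I have run into a problem. I use the following code: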

        // Show the input picture in the JavaFX ImageView "imga"
        String imagePath = "Java.jpg";
        imga.setImage(new Image(new File(imagePath).toURI().toString()));

        // Directory that holds the frozen graph (.pb) and its label file
        String modelPath = "C:\\Users\\AngruAdminAlex\\Documents\\inception_dec_2015";
        System.out.println("Opening: " + modelPath);
        modelselected = true;
        graphDef = readAllBytesOrExit(Paths.get(modelPath, "tensorflow_inception_graph.pb"));
        labels = readAllLinesOrExit(Paths.get(modelPath, "imagenet_comp_graph_label_strings.txt"));

        // Read the raw JPEG bytes of the image
        System.out.println("Image Path: " + imagePath);
        byte[] imageBytes = readAllBytesOrExit(Paths.get(imagePath));
        System.out.println("Reading complete.");

        // Wrap the bytes in a tensor, run the graph and report the most likely label
        try (Tensor image1 = Tensor.create(imageBytes)) {
            float[] labelProbabilities = executeInceptionGraph(graphDef, image1);
            int bestLabelIdx = maxIndex(labelProbabilities);
            String bestMatch = String.format("BEST MATCH: %s (%.2f likely)",
                    labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f);
            System.out.println(bestMatch);
            results.setText(bestMatch);
        }
    }   // closes the enclosing method (not shown)
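The snippet relies on fields (graphDef, labels, imga, results, modelselected) and helpers that are not shown. The helper names match those in TensorFlow's LabelImage.java Java example, which this code appears to be adapted from. For anyone who wants to compile the snippet, minimal JDK-only sketches of the three helpers that do not touch TensorFlow might look like this. The class name LabelHelpers is hypothetical, and these are not copied from that example; executeInceptionGraph is left out because it depends on how your graph's input and output nodes are named.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

// Hypothetical stand-ins for the helpers the question's snippet calls.
final class LabelHelpers {
    private LabelHelpers() {}

    // Reads a whole file (the .pb graph or the JPEG); exits the program on failure.
    static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
            return null; // not reached: System.exit does not return
        }
    }

    // Reads the label file, one label per line, in the same order as the model's outputs.
    static List<String> readAllLinesOrExit(Path path) {
        try {
            return Files.readAllLines(path, StandardCharsets.UTF_8);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
            return null; // not reached
        }
    }

    // Index of the highest probability; on ties the first one wins.
    static int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }
}
```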

When I use the pre-trained Inception v3 model, everything works, but when I use the model I trained myself, I get this error:

Opening: C:\Users\AngruAdminAlex\Documents\1.10

Image Path: C:\Users\AngruAdminAlex\Pictures\Test\20180907_213807.jpg
Exception in thread "AWT-EventQueue-0" java.lang.IllegalArgumentException: NodeDef mentions attr 'dilations' not in Op<name=Conv2D; signature=input:T, filter:T -> output:T; attr=T:type,allowed=[DT_HALF, DT_FLOAT, DT_DOUBLE]; attr=strides:list(int); attr=use_cudnn_on_gpu:bool,default=true; attr=padding:string,allowed=["SAME", "VALID"]; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]>; NodeDef: module_apply_default/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="VALID", strides=[1, 2, 2, 1], use_cudnn_on_gpu=true](module_apply_default/hub_input/Sub, module_apply_default/InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D/ReadVariableOp). (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).
    at org.tensorflow.Graph.importGraphDef(Native Method)
    at org.tensorflow.Graph.importGraphDef(Graph.java:130)
    at org.tensorflow.Graph.importGraphDef(Graph.java:114)

Model link: model
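If the program lets the user pick a model folder, it can help to recognize this particular failure and say what to do about it, since the message names both the unknown attribute and the op. A minimal sketch, assuming only that the message keeps the "NodeDef mentions attr 'X' not in Op<name=Y" wording seen in the error above (that wording is not a stable API and may differ between TensorFlow versions); the class GraphDefErrors and its method are hypothetical.

```java
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper that turns TensorFlow's "unknown attribute" import error into a hint.
final class GraphDefErrors {
    private GraphDefErrors() {}

    // Matches e.g. "NodeDef mentions attr 'dilations' not in Op<name=Conv2D; ..."
    private static final Pattern UNKNOWN_ATTR =
            Pattern.compile("NodeDef mentions attr '([^']+)' not in Op<name=([^;>]+)");

    // Returns a readable hint if the message looks like a version mismatch, otherwise empty.
    static Optional<String> versionMismatchHint(String message) {
        if (message == null) {
            return Optional.empty();
        }
        Matcher m = UNKNOWN_ATTR.matcher(message);
        if (!m.find()) {
            return Optional.empty();
        }
        return Optional.of(String.format(
                "Op %s in this graph uses attribute '%s', which the loaded TensorFlow library "
                        + "does not know. The graph was probably exported by a newer TensorFlow; "
                        + "upgrade the Java library.",
                m.group(2), m.group(1)));
    }
}
```

For example, wrap graph.importGraphDef(graphDef) in a try block, catch the IllegalArgumentException it throws, and pass e.getMessage() to versionMismatchHint; if a hint comes back, show it in results instead of letting the exception reach the AWT event thread.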

2 answers:

Answer 0 (score: 0)

Have a look at the following link, which describes a similar problem:

https://github.com/tensorflow/models/issues/4093

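It seems you may have to update TensorFlow to 1.8 or later. The last sentence of the error message points the same way: the TensorFlow library your Java program loads is older than the TensorFlow that generated the graph, so it does not know the dilations attribute that the newer version writes on Conv2D nodes. Make sure the Java library is at least as new as the TensorFlow version used to train and export the model.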
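The advice above could be turned into a check at startup, so the program reports an outdated library before it tries to import the graph. A minimal sketch, assuming the minimum is the 1.8 suggested in this answer; the class VersionCheck is hypothetical and uses only the JDK. Versions are compared numerically, component by component, because a plain string comparison gets it wrong ("1.10.0".compareTo("1.8.0") is negative).

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper: numeric comparison of dotted version strings such as "1.10.0".
final class VersionCheck {
    private VersionCheck() {}

    private static final Pattern LEADING_DIGITS = Pattern.compile("^(\\d+)");

    // Parses "1.10.0" or "1.13.0-rc1" into numeric parts; non-numeric suffixes are ignored.
    static int[] parse(String version) {
        String[] parts = version.trim().split("\\.");
        int[] nums = new int[parts.length];
        for (int i = 0; i < parts.length; i++) {
            Matcher m = LEADING_DIGITS.matcher(parts[i]);
            nums[i] = m.find() ? Integer.parseInt(m.group(1)) : 0;
        }
        return nums;
    }

    // True if actual >= minimum; missing trailing components count as 0.
    static boolean isAtLeast(String actual, String minimum) {
        int[] a = parse(actual);
        int[] b = parse(minimum);
        for (int i = 0; i < Math.max(a.length, b.length); i++) {
            int x = i < a.length ? a[i] : 0;
            int y = i < b.length ? b[i] : 0;
            if (x != y) {
                return x > y;
            }
        }
        return true;
    }
}
```

In the TensorFlow 1.x Java API, org.tensorflow.TensorFlow.version() returns the version of the loaded native library, so a call such as VersionCheck.isAtLeast(TensorFlow.version(), "1.8.0") before importGraphDef would catch the mismatch early.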

Answer 1 (score: 0)

How to retrain InceptionV3 for this code:

  1. In the code, change:

    Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("softmax").run().get(0);

     to the name of the output layer that retrain.py adds (final_result by default):

    Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0);

  2. here

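  3. Download the TensorFlow native library.
  4. Unpack the .jar file as you would a .zip and copy the library files into your project folder.
  5. Get retrain.py from here.
  6. Open a command prompt and run: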

    python (path to retrain.py) --image_dir (path to images) --output_graph (path to graph file (.pb)) --output_labels (path to label file (.txt))
    
  7. Find the generated .pb file, rename it to label.pb and use it in your code.
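Steps 1 and 7 both come down to the same thing: a retrained graph differs from the stock one in its file names and its output node, while the question's code hardcodes both. One way to switch models without editing the code each time is to keep those values together. A minimal sketch: the class ModelConfig and its factory methods are hypothetical; the stock file names come from the question and "softmax" from step 1; "final_result" is retrain.py's default --final_tensor_name. The input node "DecodeJpeg/contents" for the retrained graph is taken from this answer and is an assumption: a graph retrained differently (the module_apply_default node names in the error suggest a TF Hub module was used) may expose a different input node, so inspect your graph before relying on it.

```java
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;

// Hypothetical bundle of everything that differs between the stock graph and a retrained one.
final class ModelConfig {
    final Path graphFile;    // frozen GraphDef (.pb)
    final Path labelFile;    // one label per line, in output order
    final String inputNode;  // node fed with the raw JPEG bytes
    final String outputNode; // node fetched for the class probabilities

    ModelConfig(Path graphFile, Path labelFile, String inputNode, String outputNode) {
        this.graphFile = Objects.requireNonNull(graphFile);
        this.labelFile = Objects.requireNonNull(labelFile);
        this.inputNode = Objects.requireNonNull(inputNode);
        this.outputNode = Objects.requireNonNull(outputNode);
    }

    // The stock inception_dec_2015 graph, with the file and node names used in the question.
    static ModelConfig stockInception(String modelDir) {
        return new ModelConfig(
                Paths.get(modelDir, "tensorflow_inception_graph.pb"),
                Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"),
                "DecodeJpeg/contents",
                "softmax");
    }

    // A graph produced by retrain.py: file names are whatever was passed to --output_graph
    // and --output_labels. The input node name is an assumption taken from the answer.
    static ModelConfig retrained(Path graphFile, Path labelFile) {
        return new ModelConfig(graphFile, labelFile, "DecodeJpeg/contents", "final_result");
    }
}
```

With a Session s and an input tensor image, the call from step 1 then becomes s.runner().feed(cfg.inputNode, image).fetch(cfg.outputNode).run().get(0), and cfg.graphFile and cfg.labelFile replace the hardcoded Paths.get(...) calls.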