DataType 20 is not recognized in Java (version 1.15.0)

Posted: 2021-03-16 11:44:22

Tags: java tensorflow

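I am trying to use a model trained in Python to run predictions in Java. The pipeline works well in Python, and I want to perform the same prediction in Java. I named the model's input_1 "inputTensor" and its output "outputTensor". In Java, the input name has changed from "inputTensor" to "serving_default_inputTensor".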
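The renamed input follows the convention TF2 SavedModels use for signature placeholders: each input of the `serving_default` signature shows up as an op named `serving_default_<input name>`. Below is a minimal sketch for picking those ops out of a list of operation names, assuming the names were already collected (for example from `oprt.name()` in the `graph.operations()` loop of the program below). The class `SignatureInputFinder` and the method `signatureInputs` are hypothetical helpers, not part of TensorFlow.

```java
import java.util.ArrayList;
import java.util.List;

public class SignatureInputFinder {
    // Hypothetical helper: returns the op names that look like inputs of the
    // given signature. Assumes the "<signatureKey>_<inputName>" naming seen
    // in the question (e.g. "serving_default_inputTensor").
    public static List<String> signatureInputs(Iterable<String> opNames, String signatureKey) {
        String prefix = signatureKey + "_";
        List<String> result = new ArrayList<>();
        for (String name : opNames) {
            if (name.startsWith(prefix)) {
                result.add(name);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Example op names; in practice they would come from graph.operations()
        List<String> ops = List.of(
                "serving_default_inputTensor",
                "outputTensor_1/kernel",
                "StatefulPartitionedCall");
        System.out.println(signatureInputs(ops, "serving_default"));
        // prints [serving_default_inputTensor]
    }
}
```

A plain prefix match is used because the question shows exactly one such rename; a model with several inputs would simply return several names.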

The test image is already normalized and has a single channel (grayscale), not RGB.

I converted the BufferedImage to a Tensor, but got the error "DataType 20 is not recognized in Java (version 1.15.0)". 1.15.0 is the version of TensorFlow for Java that I am using.

import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.File;
import java.io.IOException;
import java.util.Iterator;

import javax.imageio.ImageIO;
// tensorflow libraries
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

@SuppressWarnings("unused")
public class JavaTensorflowPredict {
    public static void main(String[] args) throws IOException {
        SavedModelBundle theModel = SavedModelBundle.load("mymodel/", "serve");
        Session sess = theModel.session();
        Graph graph = theModel.graph();
        
        File inputImage = new File("../test_images/test_01.png");
        BufferedImage image1 = ImageIO.read(inputImage);
        
        System.out.println(TensorFlow.version());  
        Tensor<Float> inputTensor1 = convertBufferedImageToTensor(image1, 128, 128);
        System.out.println(inputTensor1.dataType().toString());
        Iterator<Operation> iterOpts = graph.operations();
        while (iterOpts.hasNext()) {
            Operation oprt = iterOpts.next();
            System.out.println(oprt.name()); // print the name of every operation in the graph
        } 
        System.out.println("...predicting...\n");
        java.util.List<Tensor<?>> y = sess.runner().feed("serving_default_inputTensor", inputTensor1).fetch("outputTensor_1/kernel").run();
        System.out.println("...done...\n"); 
    }
    public static BufferedImage resize(BufferedImage img, int width, int height) {
        // obtained from https://github.com/mstritt/orbit-image-analysis
        //int type = img.getType()>0?img.getType():BufferedImage.TYPE_INT_RGB;
        int type = BufferedImage.TYPE_INT_RGB;
        //BufferedImage resizedImage = new BufferedImage(roundP2(width), roundP2(height), type);
        BufferedImage resizedImage = new BufferedImage(width, height, type);
        Graphics2D g = resizedImage.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, width, height, null);
        g.dispose();
        return resizedImage;
    }
    public static Tensor<Float> convertBufferedImageToTensor(BufferedImage image, int targetWidth, int targetHeight) {
        // obtained from https://github.com/mstritt/orbit-image-analysis
        //if (image.getWidth()!=DESIRED_SIZE || image.getHeight()!=DESIRED_SIZE)
        {
            // also make it an RGB image
            // image = resize(image, targetWidth, targetHeight);
            // image = resize(image,image.getWidth(), image.getHeight());
        }
        int width = image.getWidth();
        int height = image.getHeight();
        Raster r = image.getRaster();
        int[] rgb = new int[1];
        //int[] data = new int[width * height];
        //image.getRGB(0, 0, width, height, data, 0, width);
        float[][][][] rgbArray = new float[1][height][width][1];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                rgb = r.getPixel(j,i,rgb);
                rgbArray[0][i][j][0] = rgb[0];
            }
        }
        return Tensor.create(rgbArray, Float.class);
    }
}
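The conversion loop above reads band 0 of every pixel, which is only correct if the PNG really decodes to a single-band raster. Here is a JDK-only sketch of the same NHWC layout ([1][height][width][1]) that rejects multi-band images instead of silently using the first channel. The class name `GrayscaleToNhwc` is hypothetical, and no rescaling is applied, on the assumption (matching the Python script) that the model expects the raw pixel values as read.

```java
import java.awt.image.BufferedImage;
import java.awt.image.Raster;

public class GrayscaleToNhwc {
    // Hypothetical helper: copies a single-band image into a float array of
    // shape [1][height][width][1] (batch of one, NHWC), the layout passed to
    // Tensor.create in the question.
    public static float[][][][] toNhwc(BufferedImage image) {
        Raster raster = image.getRaster();
        if (raster.getNumBands() != 1) {
            throw new IllegalArgumentException(
                    "expected 1 band (grayscale), got " + raster.getNumBands());
        }
        int width = raster.getWidth();
        int height = raster.getHeight();
        float[][][][] out = new float[1][height][width][1];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // Raw sample of band 0, no scaling (assumption: same as the
                // uint8 values cv2.imread feeds to the model in Python)
                out[0][y][x][0] = raster.getSample(x, y, 0);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(3, 2, BufferedImage.TYPE_BYTE_GRAY);
        img.getRaster().setSample(2, 1, 0, 200);
        float[][][][] t = toNhwc(img);
        System.out.println(t[0].length + "x" + t[0][0].length + ", value=" + t[0][1][2][0]);
        // prints 2x3, value=200.0
    }
}
```

In the question's program the resulting array could be passed to `Tensor.create(array, Float.class)` exactly as before; the check only makes an accidental RGB input fail early.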

This is the output:

Exception in thread "main" java.lang.IllegalArgumentException: DataType 20 is not recognized in Java (version 1.15.0)
at org.tensorflow.DataType.fromC(DataType.java:85)
at org.tensorflow.Tensor.fromHandle(Tensor.java:540)
at org.tensorflow.Session$Runner.runHelper(Session.java:343)
at org.tensorflow.Session$Runner.run(Session.java:276)
at tftest2.main(tftest2.java:39)
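The number in the exception is the raw `TF_DataType` code from the TensorFlow C API, which TF Java 1.15 tries to map in `DataType.fromC` (see the first stack frame). Code 20 is `TF_RESOURCE`, the type of resource handles such as TF2 variables. That fits the fetch of `outputTensor_1/kernel`, which in a TF2 SavedModel is most likely the layer's kernel variable handle rather than the model's prediction. The lookup below is a small sketch for decoding such codes; the name table is copied from the C enum, while the set of codes TF Java 1.x can map reflects my reading of its `DataType` enum and should be treated as an assumption. `TfDataTypeCodes` is a hypothetical class name.

```java
public class TfDataTypeCodes {
    // TF_DataType names from the TensorFlow C API, indexed by numeric code
    // (index 0 is unused).
    private static final String[] NAMES = {
        null, "TF_FLOAT", "TF_DOUBLE", "TF_INT32", "TF_UINT8", "TF_INT16",
        "TF_INT8", "TF_STRING", "TF_COMPLEX64", "TF_INT64", "TF_BOOL",
        "TF_QINT8", "TF_QUINT8", "TF_QINT32", "TF_BFLOAT16", "TF_QINT16",
        "TF_QUINT16", "TF_UINT16", "TF_COMPLEX128", "TF_HALF", "TF_RESOURCE",
        "TF_VARIANT", "TF_UINT32", "TF_UINT64"
    };

    public static String name(int code) {
        if (code > 0 && code < NAMES.length) {
            return NAMES[code];
        }
        return "UNKNOWN(" + code + ")";
    }

    // Assumption: the codes org.tensorflow.DataType in TF Java 1.x maps
    // (FLOAT, DOUBLE, INT32, UINT8, STRING, INT64, BOOL).
    public static boolean knownToJava1x(int code) {
        switch (code) {
            case 1: case 2: case 3: case 4: case 7: case 9: case 10:
                return true;
            default:
                return false;
        }
    }

    public static void main(String[] args) {
        int code = 20; // from "DataType 20 is not recognized in Java"
        System.out.println(code + " = " + name(code)
                + ", mapped by TF Java 1.x: " + knownToJava1x(code));
        // prints 20 = TF_RESOURCE, mapped by TF Java 1.x: false
    }
}
```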

Note: I use the following script to make predictions in Python:

import tensorflow as tf
import cv2
import numpy as np

the_model = tf.keras.models.load_model("mymodel")
image1 = cv2.imread("test_01.png",0)
# convert image from (128,128) to (1,128,128,1)
image1 = np.reshape(image1, (1,)+image1.shape+(1,))
predict = the_model(image1, training = True)

2 answers:

Answer 0 (score: 0)

Answer 1 (score: 0)

Here is the answer: I trained the model with TensorFlow 2.4.1 in Python, then loaded it in Java with TF1 (version 1.15.0). Switching to TF2 in Java (tensorflow-core-platform, version 0.3.1) solves the problem, because tensorflow-core-platform 0.3.0 for Java can load models saved with TensorFlow 2.4.1 and later in Python.
