Java Tensorflow + Keras等效于model.predict()

时间:2020-06-07 06:08:12

标签: java tensorflow keras tensorflow-serving

在python中,您只需将numpy数组传递给predict()即可从模型中获取预测。将Java与SavedModelBundle一起使用等效于什么?

Python

model = tf.keras.models.Sequential([
  # layers go here
])
model.compile(...)
model.fit(x_train, y_train)

predictions = model.predict(x_test_maxabs) # <= This line 

Java

SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?

3 个答案:

答案 0 :(得分:2)

TensorFlow Python自动将您的NumPy数组转换为tf.Tensor。在TensorFlow Java中,您可以直接操纵张量。

现在SavedModelBundle没有predict方法。您需要使用SessionRunner并输入输入张量来获取并运行会话。

例如,基于下一代TF Java(https://github.com/tensorflow/java),您的代码最终看起来像这样(请注意,自您的代码示例以来,我在这里对x_test_maxabs进行了大量假设)没有清楚说明它的来源):

try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
        Tensor<TFloat32> output = model.session()
            .runner()
            .feed("input_name", input)
            .fetch("output_name")
            .run()
            .expect(TFloat32.class)) {

        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }        
}

如果不确定图形中输入/输出张量的名称是什么,可以通过查看签名定义以编程方式获取:

model.metaGraphDef().getSignatureDefMap().get("serving_default")

答案 1 :(得分:1)

您可以尝试Deep Java Library (DJL)

DJL在内部使用Tensorflow Java并提供高级API以使其易于推理:

Criteria<Image, Classifications> criteria =
    Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelUrls("https://example.com/squeezenet.zip")
        .optTranslator(ImageClassificationTranslator
               .builder().addTransform(new ToTensor()).build())
        .build();

try (ZooModel<Image, Classification> model = ModelZoo.load(criteria);
        Predictor<Image, Classification> predictor = model.newPredictor()) {
    Image image = ImageFactory.getInstance().fromUrl("https://myimage.jpg");
    Classification result = predictor.predict(image);
}


签出github仓库:https://github.com/awslabs/djl

有一个博客文章:https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6

可以找到演示项目:https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md

答案 2 :(得分:0)

0.3.1 API 中:

val model: SavedModelBundle = SavedModelBundle.load("path/to/model", "serve")

val inputTensor = TFloat32.tesnorOf(..)

val function: ConcreteFunction = model.function(Signature.DEFAULT_KEY)
val result: Tensor = function.call(inputTensor) // u can cast to type you expect, a type of returning tensor can be checked by signature: model.function("serving_default").signature().toString()

获得任何子类型的结果张量后,您可以迭代其值。在我的例子中,我有一个形状为 TFloat32(1, 56),所以我找到了 result.get(0, idx)

的最大值