在python中,您只需将numpy数组传递给predict()
即可从模型中获取预测。将Java与SavedModelBundle
一起使用等效于什么?
model = tf.keras.models.Sequential([
# layers go here
])
model.compile(...)
model.fit(x_train, y_train)
predictions = model.predict(x_test_maxabs) # <= This line
SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?
答案 0 :(得分:2)
TensorFlow Python自动将您的NumPy数组转换为tf.Tensor
。在TensorFlow Java中,您可以直接操纵张量。
现在SavedModelBundle
没有predict
方法。您需要使用SessionRunner
并输入输入张量来获取并运行会话。
例如,基于下一代TF Java(https://github.com/tensorflow/java),您的代码最终看起来像这样(请注意,自您的代码示例以来,我在这里对x_test_maxabs
进行了大量假设)没有清楚说明它的来源):
try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
Tensor<TFloat32> output = model.session()
.runner()
.feed("input_name", input)
.fetch("output_name")
.run()
.expect(TFloat32.class)) {
float prediction = output.data().getFloat();
System.out.println("prediction = " + prediction);
}
}
如果不确定图形中输入/输出张量的名称是什么,可以通过查看签名定义以编程方式获取:
model.metaGraphDef().getSignatureDefMap().get("serving_default")
答案 1 :(得分:1)
您可以尝试Deep Java Library (DJL)。
DJL在内部使用Tensorflow Java并提供高级API以使其易于推理:
Criteria<Image, Classifications> criteria =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelUrls("https://example.com/squeezenet.zip")
.optTranslator(ImageClassificationTranslator
.builder().addTransform(new ToTensor()).build())
.build();
try (ZooModel<Image, Classification> model = ModelZoo.load(criteria);
Predictor<Image, Classification> predictor = model.newPredictor()) {
Image image = ImageFactory.getInstance().fromUrl("https://myimage.jpg");
Classification result = predictor.predict(image);
}
签出github仓库:https://github.com/awslabs/djl
有一个博客文章:https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6
可以找到演示项目:https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md
答案 2 :(得分:0)
在 0.3.1
API 中:
val model: SavedModelBundle = SavedModelBundle.load("path/to/model", "serve")
val inputTensor = TFloat32.tesnorOf(..)
val function: ConcreteFunction = model.function(Signature.DEFAULT_KEY)
val result: Tensor = function.call(inputTensor) // u can cast to type you expect, a type of returning tensor can be checked by signature: model.function("serving_default").signature().toString()
获得任何子类型的结果张量后,您可以迭代其值。在我的例子中,我有一个形状为 TFloat32
的 (1, 56)
,所以我找到了 result.get(0, idx)