我正在尝试将预训练模型(使用python)加载到java项目中。
问题是
rd /?
代码
2>nul
训练和保存模型的python代码
Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
at org.tensorflow.Tensor.scalarFloat(Native Method)
at org.tensorflow.Tensor.floatValue(Tensor.java:279)
似乎我已经成功加载了模型和图形,因为
float[] arr=context.csvintarr(context.getPlayer(playerId));
float[][] martix={arr};
try (Graph g=model.graph()){
try(Session s=model.session()){
Tensor y=s.runner().feed("input/input", Tensor.create(martix))
.fetch("out/predict").run().get(0);
logger.info("a {}",y.floatValue());
}
}
成功打印出名称。
答案 0 :(得分:1)
错误消息表明输出张量不是浮点值标量,因此它可能是更高维度张量(向量,矩阵)。
您可以使用System.out.println(y.toString())
或专门使用y.shape()
来了解张量的形状。在您的Python代码中,这对应于y.shape
。
对于非标量,使用y.copyTo
获取浮点数组(对于向量)或浮点数组数组(对于矩阵)等。
例如:
System.out.println(y);
// If the above printed something like:
// "FLOAT tensor with shape [1]"
// then you can get the values using:
float[] vector = y.copyTo(new float[1]);
// If the shape was something like [2, 3]
// then you can get the values using:
float[][] matrix = y.copyTo(new float[2][3]);
有关floatValue()
vs copyTo
vs writeTo
的更多信息,请参阅Tensor
javadoc。
希望有所帮助。