tensorflow java api err:java.lang.IllegalStateException:Tensor不是标量

时间:2017-08-20 17:30:21

标签: java python python-3.x tensorflow

我正在尝试将预训练模型(使用python)加载到java项目中。

问题是

rd /?

代码

2>nul

训练和保存模型的python代码

 Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
    at org.tensorflow.Tensor.scalarFloat(Native Method)
    at org.tensorflow.Tensor.floatValue(Tensor.java:279)

似乎我已经成功加载了模型和图形,因为

    float[] arr=context.csvintarr(context.getPlayer(playerId));
    float[][] martix={arr};
    try (Graph g=model.graph()){
        try(Session s=model.session()){

            Tensor y=s.runner().feed("input/input", Tensor.create(martix))
            .fetch("out/predict").run().get(0);
            logger.info("a {}",y.floatValue());
        }
    }

成功打印出名称。

1 个答案:

答案 0 :(得分:1)

错误消息表明输出张量不是浮点值标量,因此它可能是更高维度张量(向量,矩阵)。

您可以使用System.out.println(y.toString())或专门使用y.shape()来了解张量的形状。在您的Python代码中,这对应于y.shape

对于非标量,使用y.copyTo获取浮点数组(对于向量)或浮点数组数组(对于矩阵)等。

例如:

System.out.println(y);
// If the above printed something like:
// "FLOAT tensor with shape [1]"
// then you can get the values using:
float[] vector = y.copyTo(new float[1]);

// If the shape was something like [2, 3]
// then you can get the values using:
float[][] matrix = y.copyTo(new float[2][3]);

有关floatValue() vs copyTo vs writeTo的更多信息,请参阅Tensor javadoc

希望有所帮助。