Java中的Tensorflow:inferenceInterface.fetch转换为多维数组

时间:2018-08-17 15:12:04

标签: java python arrays tensorflow shape

我正在Android上的Java中使用经过训练的Tensorflow模型。我正在尝试提取中间操作的输出。

我提取的张量具有形状(150、150、256)。

我已声明输出目标为

private float[] hybridValues;
hybridValues = new float[150 * 150 * 256];

然后我使用获取输出。

inferenceInterface.fetch(OUTPUT_NODE, hybridValues);

这些值很好,但是它们存储为一维数组。有没有办法让inferenceinterface.fetch返回多维数组?

我尝试将hybridValue声明为三维浮点数组,但由于fetch方法期望使用1D数组,因此无法正常工作。

最终目标是将我的输出传递给Python程序,该程序会将值馈送到形状相同(150、150、256)的张量。

为了进行比较,Python a_output = graph.get_tensor_by_name('a2b_generator/Conv_7/Relu:0') 返回的ndarray的值与目标张量的形状相同。

1 个答案:

答案 0 :(得分:0)

我在桌面上使用tensorflow和java(可能会有所不同),我要做的就是创建一个具有正确大小的多维数组,然后将值复制到例如Tensor.copyTo(Object dst)