将Java数组快速转换为NumPy数组(Py4J)

时间:2016-08-23 08:24:20

标签: java python numpy py4j

有一些很好的例子如何将NumPy数组转换为Java数组,但反之亦然 - 如何将数据从Java对象转换回NumPy数组。我有一个像这样的Python脚本:

    from py4j.java_gateway import JavaGateway
    gateway = JavaGateway()            # connect to the JVM
    my_java = gateway.jvm.JavaClass();  # my Java object
    ....
    int_array=my_java.doSomething(int_array); # do something

    my_numpy=np.zeros((size_y,size_x));
    for jj in range(size_y):
        for ii in range(size_x):
            my_numpy[jj,ii]=int_array[jj][ii];

my_numpy是Numpy数组,int_array是整数的Java数组 - int[ ][ ]种数组。在Python脚本中初始化为:

    int_class=gateway.jvm.int       # make int class
    double_class=gateway.jvm.double # make double class

    int_array = gateway.new_array(int_class,size_y,size_x)
    double_array = gateway.new_array(double_class,size_y,size_x)

虽然它可以正常工作,但它不是最快的方式而且运行速度相当慢 - 对于~1000x1000阵列,转换时间超过5分钟。

有没有办法在合理的时间内完成这个?

如果我尝试:

    test=np.array(int_array)

我明白了:

    ValueError: invalid __array_struct__

2 个答案:

答案 0 :(得分:3)

我遇到了类似的问题,发现一个解决方案比我测试的情况快了大约220倍:为了将一个1628x120的短整数数组从Java转移到Numpy,运行时间从11秒减少到0.05秒。感谢remove,我开始研究this related StackOverflow question,结果发现py4j有效地将Java字节数组转换为Python字节对象,反之亦然(通过值传递,而不是通过引用传递)。这是一种相当迂回的做事方式,但并不太难。

因此,如果要传输维度为intArray x iMax的整数数组jMax(并且为了示例,我假设这些都存储为实例变量在您的对象中),您可以先编写一个Java函数将其转换为byte [],如下所示:

public byte[] getByteArray() {
    // Set up a ByteBuffer called intBuffer
    ByteBuffer intBuffer = ByteBuffer.allocate(4*iMax*jMax); // 4 bytes in an int
    intBuffer.order(ByteOrder.LITTLE_ENDIAN); // Java's default is big-endian

    // Copy ints from intArray into intBuffer as bytes
    for (int i = 0; i < iMax; i++) {
        for (int j = 0; j < jMax; j++){
            intBuffer.putInt(intArray[i][j]);
        }
    }

    // Convert the ByteBuffer to a byte array and return it
    byte[] byteArray = intBuffer.array();
    return byteArray;
}

然后,您可以编写Python 3代码来接收字节数组并将其转换为正确形状的numpy数组:

byteArray = gateway.entry_point.getByteArray()
intArray = np.frombuffer(byteArray, dtype=np.int32)
intArray = intArray.reshape((iMax, jMax))

答案 1 :(得分:2)

我有一个类似的问题,只是试图绘制我从Java端通过py4j获得的光谱矢量(Java数组)。 这里,通过list()函数实现从Java Array到Python列表的转换。这可能会提供一些线索,如何使用它来填充NumPy数组...

vectors = space.getVectorsAsArray(); # Java array (MxN)
wvl = space.getAverageWavelengths(); # Java array (N)

wavelengths = list(wvl)

import matplotlib.pyplot as mp
mp.hold
for i, dataset in enumerate(vectors):
    mp.plot(wavelengths, list(dataset))

这是否比您使用的嵌套for循环更快,我不能说,但它也可以解决问题:

import numpy
from numpy  import array
x = array(wavelengths)
v = array(list(vectors))

mp.plot(x, numpy.rot90(v))