NumPy复数数组和回调

时间:2016-01-09 10:35:16

标签: python numpy callback

我想把“How to access arrays passed to ctypes callbacks as numpy arrays?”这个问题扩展到复数数组的情况。

test.py

#!/usr/bin/env python

import numpy as np
import numpy.ctypeslib as npct
import ctypes
import os.path

array_1d_double = npct.ndpointer(dtype=np.complex128, ndim=1, flags='CONTIGUOUS')  # argument type: 1-D contiguous complex128 array

callback_func = ctypes.CFUNCTYPE(
    None,            # return
    ctypes.POINTER(np.complex128), # x -- this line raises the TypeError: np.complex128 is a NumPy type, not a ctypes type
    ctypes.c_int     # n
)

# Load libtest from the directory containing this script and declare the
# signature of callback_test: (complex array, element count, callback) -> None.
libtest = npct.load_library('libtest', os.path.dirname(__file__))
libtest.callback_test.restype = None
libtest.callback_test.argtypes = [array_1d_double, ctypes.c_int, callback_func]


@callback_func
def callback(x, n):
    """Print the n elements behind pointer x together with n itself."""
    # Wrap the raw pointer as a NumPy array of length n before formatting.
    view = npct.as_array(x, shape=(n,))
    message = "x: {0}, n: {1}".format(view[:n], n)
    print(message)


if __name__ == '__main__':
    # Six complex samples handed to the C routine, which reports back
    # through `callback`.
    samples = np.array([20, 13, 8, 100, 1, 3], dtype=np.complex128)
    libtest.callback_test(samples, len(samples), callback)

test.c

#include <complex.h>

/*
 * Type of the callback that callback_test invokes:
 *   x - the array pointer that was handed to callback_test
 *   n - the element count that was handed to callback_test
 */
typedef void (*callback_t)(
    void* *x,
    int n
);

/*
 * Divide every one of the n complex elements behind x in place by
 * 1, 2, ..., 5 in turn, calling callback(x, n) after each pass.
 */
void callback_test(void** x, int n, callback_t callback) {
    /* x is declared as void** but is treated as an array of complex doubles. */
    double _Complex *values = (double _Complex *)x;

    for (int divisor = 1; divisor <= 5; ++divisor) {
        for (int k = 0; k < n; ++k) {
            values[k] /= divisor;
        }

        /* Report the updated buffer through the original pointer. */
        callback(x, n);
    }
}

运行后得到如下错误:

  

Traceback (most recent call last):
  File "test.py", line 12, in <module>
    ctypes.POINTER(np.complex128), # x
TypeError: type must have storage info

有什么办法可以解决这个问题吗?

1 个答案:

答案 0 :(得分:1)

ctypes 无法直接处理 numpy 类型,因此“POINTER(np.complex128)”会导致错误。您可以改用 ndpointer 返回的类型作为回调参数的类型,并去掉 as_array 中的 shape 参数。下面的代码在我这里可以正常运行:

#!/usr/bin/env python

import numpy as np
import numpy.ctypeslib as npct
import ctypes
import os.path

# Argument type shared by the C entry point and the callback:
# a 1-D contiguous complex128 array.
array_1d_double = npct.ndpointer(
    dtype=np.complex128,
    ndim=1,
    flags='CONTIGUOUS',
)

# Callback prototype: returns nothing, takes (array, element count).
callback_func = ctypes.CFUNCTYPE(None, array_1d_double, ctypes.c_int)

# Load libtest from the directory containing this script and declare the
# signature of callback_test: (array, element count, callback) -> None.
libtest = npct.load_library('libtest', os.path.dirname(__file__))
libtest.callback_test.argtypes = [array_1d_double, ctypes.c_int, callback_func]
libtest.callback_test.restype = None


@callback_func
def callback(x, n):
    """Print the n elements behind pointer x together with n itself."""
    # as_array is called without a shape, so the length is recorded on the
    # pointer object first.
    # NOTE(review): `_shape_` is a private attribute of the ndpointer type;
    # confirm this still works with the NumPy version in use.
    x._shape_ = (n,)
    values = npct.as_array(x)
    message = "x: {0}, n: {1}".format(values[:n], n)
    print(message)



if __name__ == '__main__':
    # Nine complex samples handed to the C routine, which reports back
    # through `callback`.
    samples = np.array([20, 13, 8, 100, 1, 3, 11, 12, 13], dtype=np.complex128)
    libtest.callback_test(samples, len(samples), callback)