Cython自定义数据类型ndarray

时间:2018-06-14 12:40:28

标签: numpy cython custom-data-type

我为我的np.ndarrays创建了一个dtype:

particle_t = np.dtype([
    ('position', float, 2),
    ('momentum', float, 2),
    ('velocity', float, 2),
    ('force', float, 2),
    ('charge', int, 1),
])

根据官方的例子,可以打电话:

def my_func(np.ndarray[dtype, dim] particles):

但是当我尝试编译时:

def tester(np.ndarray[particle_t, ndim = 1] particles):

我收到Invalid type错误。我见过的另一种使用方法是使用像int[:]这样的内存视图。尝试def tester(particle_t[:] particles):会导致: 'particle_t' is not a type identifier

我该如何解决这个问题?

1 个答案:

答案 0 :(得分:2)

显然,就Cython而言,particle_t不是类型而是Python对象。

类似于np.int32是一个Python对象,因此

def tester(np.ndarray[np.int32] particles):     #doesn't work!
       pass

无法工作,您需要使用相应的类型,即np.int32_t

 def tester(np.ndarray[np.int32_t] particles):  #works!
      pass

但是particle_t的相应类型是什么?您需要创建一个打包的结构,它将镜像您的numpy类型。这是一个简化版本:

#Python code:
particle_t = np.dtype([
    ('position', np.float64, 2), #It is better to specify the number of bytes exactly!
    ('charge', np.int32, 1),  #otherwise you might be surprised...
])

和相应的Cython代码:

%%cython
import numpy as np
cimport numpy as np

cdef packed struct cy_particle_t:
    np.float64_t position_x[2]
    np.int32_t   charge

def tester(np.ndarray[cy_particle_t, ndim = 1] particles):
    print(particles[0])

它不仅可以编译和加载,而且还可以像宣传的那样工作:

>>> t=np.zeros(2, dtype=particle_t)
>>> t[:]=42
>>> tester(t)
{'charge': 42, 'position_x': [42.0, 42.0]}