使用Cython,有没有办法编写快速通用函数,这些函数适用于具有不同维度的数组?例如,对于这个简单的去除函数的情况:
import numpy as np
cimport numpy as np
ctypedef np.uint8_t DTYPEb_t
ctypedef np.complex128_t DTYPEc_t
def dealiasing1D(DTYPEc_t[:, :] data,
DTYPEb_t[:] where_dealiased):
"""Dealiasing data for 1D solvers."""
cdef Py_ssize_t ik, i0, nk, n0
nk = data.shape[0]
n0 = data.shape[1]
for ik in range(nk):
for i0 in range(n0):
if where_dealiased[i0]:
data[ik, i0] = 0.
def dealiasing2D(DTYPEc_t[:, :, :] data,
DTYPEb_t[:, :] where_dealiased):
"""Dealiasing data for 2D solvers."""
cdef Py_ssize_t ik, i0, i1, nk, n0, n1
nk = data.shape[0]
n0 = data.shape[1]
n1 = data.shape[2]
for ik in range(nk):
for i0 in range(n0):
for i1 in range(n1):
if where_dealiased[i0, i1]:
data[ik, i0, i1] = 0.
def dealiasing3D(DTYPEc_t[:, :, :, :] data,
DTYPEb_t[:, :, :] where_dealiased):
"""Dealiasing data for 3D solvers."""
cdef Py_ssize_t ik, i0, i1, i2, nk, n0, n1, n2
nk = data.shape[0]
n0 = data.shape[1]
n1 = data.shape[2]
n2 = data.shape[3]
for ik in range(nk):
for i0 in range(n0):
for i1 in range(n1):
for i2 in range(n2):
if where_dealiased[i0, i1, i2]:
data[ik, i0, i1, i2] = 0.
这里,我需要三维函数用于一维,二维和三维情况。有没有一种很好的方法来编写一个能够完成所有(合理)维度工作的函数?
PS:在这里,我尝试过使用内存视图,但我不确定这是否是正确的方法。我很惊讶在if where_dealiased[i0]: data[ik, i0] = 0.
命令生成的带注释的html中,cython -a
行不是白色的。有什么不对吗?
答案 0 :(得分:2)
我要说的第一件事是有理由想要保留3个函数,使用更通用的函数,你可能会错过来自cython编译器和c编译器的优化。
创建一个包装这3个函数的函数是非常可行的,它只需要将两个数组作为python对象,检查形状,并调用相关的其他函数。
但是如果要尝试这个,那么我要尝试的只是编写最高维度的函数,然后使用较低维度的数组通过使用new axis表示法将它们重新编写为更高维度的数组:
cdef np.uint8_t [:] a1d = np.zeros((256, ), np.uint8) # 1d
cdef np.uint8_t [:, :] a2d = a1d[None, :] # 2d
cdef np.uint8_t [:, :, :] a3d = a1d[None, None, :] # 3d
a2d[0, 100] = 42
a3d[0, 0, 200] = 108
print(a1d[100], a1d[200])
# (42, 108)
cdef np.uint8_t [:, :] data2d = np.zeros((128, 256), np.uint8) #2d
cdef np.uint8_t [:, :, :, :] data4d = data2d[None, None, :, :] #4d
data4d[0, 0, 42, 108] = 64
print(data2d[42, 108])
# 64
如您所见,内存视图可以转换为更高的维度,并可用于修改原始数据。您可能仍然希望编写一个包装函数,在将新视图传递给最高维函数之前执行这些技巧。我怀疑这个技巧在你的情况下会很好用,但是你必须要知道它是否能用你的数据做你想做的事情。
使用PS:,有一个非常简单的解释。 '额外代码'是生成索引错误,类型错误的代码,它允许您使用[-1]从数组末尾而不是start(环绕)索引。 您可以通过使用compiler directives禁用这些额外的python功能并将其减少为c数组功能,例如从整个文件中删除这些额外的代码,您可以在文件的开头包含注释:
# cython: boundscheck=False, wraparound=False, nonecheck=False
编译器指令也可以使用装饰器在函数级别应用。该文件解释说。
答案 1 :(得分:1)
您可以使用numpy.ndindex()
使用strided
对象的np.ndarray
属性以一般方式阅读扁平数组,以便位置由以下各项确定:
indices[0]*strides[0] + indices[1]*strides[1] + ... + indices[n]*strides[n]
当(strides*indices).sum()
是一维数组时,可以轻松完成strides
。下面的代码显示了如何构建一个工作示例:
#cython profile=True
#blacython wraparound=False
#blacython boundscheck=False
#blacython nonecheck=False
#blacython cdivision=True
cimport numpy as np
import numpy as np
def readNDArray(x):
if not isinstance(x, np.ndarray):
raise ValueError('x must be a valid np.ndarray object')
if x.itemsize != 8:
raise ValueError('x.dtype must be float64')
cdef np.ndarray[double, ndim=1] v # view of x
cdef np.ndarray[int, ndim=1] strides
cdef int pos
shape = list(x.shape)
strides = np.array([s//x.itemsize for s in x.strides], dtype=np.int32)
v = x.ravel()
for indices in np.ndindex(*shape):
pos = (strides*indices).sum()
v[pos] = 2.
return np.reshape(v, newshape=shape)
如果原始数组是C连续,此算法将不会复制:
def main():
# case 1
x = np.array(np.random.random((3,4,5,6)), order='F')
y = readNDArray(x)
print(np.may_share_memory(x, y))
# case 2
x = np.array(np.random.random((3,4,5,6)), order='C')
y = readNDArray(x)
print np.may_share_memory(x, y)
return 0
结果:
False
True