我正在学习cython,并且已经修改了tutorial中的代码以尝试进行数值微分:
import numpy as np
cimport numpy as np
import cython
np.import_array()
def test3(a, int order=2, int axis=-1):
cdef int i
if axis<0:
axis = len(a.shape) + axis
out = np.empty(a.shape, np.double)
cdef np.flatiter ita = np.PyArray_IterAllButAxis(a, &axis)
cdef np.flatiter ito = np.PyArray_IterAllButAxis(out, &axis)
cdef int a_axis_stride = a.strides[axis]
cdef int o_axis_stride = out.strides[axis]
cdef int axis_length = out.shape[axis]
cdef double value
while np.PyArray_ITER_NOTDONE(ita):
# first element
pt1 = <double*>((<char*>np.PyArray_ITER_DATA(ita)))
pt2 = (<double*>((<char*>np.PyArray_ITER_DATA(ita)) + 1*a_axis_stride))
pt3 = (<double*>((<char*>np.PyArray_ITER_DATA(ita)) + 2*a_axis_stride))
value = -1.5*pt1[0] + 2*pt2[0] - 0.5*pt3[0]
(<double*>((<char*>np.PyArray_ITER_DATA(ito))))[0] = value
for i in range(axis_length-2):
pt1 = (<double*>((<char*>np.PyArray_ITER_DATA(ita)) + i*a_axis_stride))
pt2 = (<double*>((<char*>np.PyArray_ITER_DATA(ita)) + (i+2)*a_axis_stride))
value = -0.5*pt1[0] + 0.5*pt2[0]
(<double*>((<char*>np.PyArray_ITER_DATA(ito)) + (i+1)*o_axis_stride))[0] = value
# last element
pt1 = (<double*>((<char*>np.PyArray_ITER_DATA(ita))+ (axis_length-3)*a_axis_stride))
pt2 = (<double*>((<char*>np.PyArray_ITER_DATA(ita))+ (axis_length-2)*a_axis_stride))
pt3 = (<double*>((<char*>np.PyArray_ITER_DATA(ita))+ (axis_length-1)*a_axis_stride))
value = 1.5*pt3[0] - 2*pt2[0] + 0.5*pt1[0]
(<double*>((<char*>np.PyArray_ITER_DATA(ito))+(axis_length-1)*o_axis_stride))[0] = value
np.PyArray_ITER_NEXT(ita)
np.PyArray_ITER_NEXT(ito)
return out
代码产生正确的结果,并且确实比没有cython的我自己的numpy实现更快。问题如下:
我考虑过只使用一个pt1 = (<double*>((<char*>np.PyArray_ITER_DATA(ita)) + i*a_axis_stride))
语句,然后使用pt1[0]
,pt1[-1]
,pt1[1]
访问数组元素,但这仅在指定的轴是最后一个。如果我要区分不同的轴(不是最后一个轴),则(<double*>((<char*>np.PyArray_ITER_DATA(ita)) + i*a_axis_stride))
指向右边的轴,而pt[-1]
和pt[1]
指向{{1}之前和之后的元素},即沿着最后一个轴。当前版本有效,但是如果我想实现需要更多得分才能进行评估的高阶微分,那么我最终将拥有许多这样的代码行,并且我不确定是否有更好/更有效的方法来使用pt[0]
或
类似pt[1]
之类的东西(沿指定轴)访问相邻点。
还有其他方法可以加快这段代码的速度吗?我正在寻找一些我可能忽略或可能会产生重大影响的细微细节。
答案 0 :(得分:1)
令我有些惊讶的是,我实际上无法使用Cython类型的memoryviews击败您的版本-numpy迭代器看起来非常快。但是,我认为我可以设法显着提高可读性,让您使用Python切片语法。唯一的限制是输入数组必须是C连续的,以使其易于重塑(我认为Fortran连续的也可能有效,但我尚未测试)
基本技巧是将所选轴之前和之后的所有轴展平,以使其成为已知的3D形状,此时您可以使用Cython内存视图。
@cython.boundscheck(False)
def test4(a,order=2,axis=-1):
assert a.flags['C_CONTIGUOUS'] # otherwise the reshape doesn't work
before = np.product(a.shape[:axis])
after = np.product(a.shape[(axis+1):])
cdef double[:,:,::1] a_new = a.reshape((before, a.shape[axis], after)) # this should not involve copying memory - it's just a new view
cdef double[:] a_slice
cdef double[:,:,::1] out = np.empty_like(a_new)
assert a_new.shape[1] > 3
cdef int m,n,i
for m in range(a_new.shape[0]):
for n in range(a_new.shape[2]):
a_slice = a_new[m,:,n]
out[m,0,n] = -1.5*a_slice[0] + 2*a_slice[1] - 0.5*a_slice[2]
for i in range(a_slice.shape[0]-2):
out[m,i+1,n] = -0.5*a_slice[i] + 0.5*a_slice[i+2]
# last element
out[m,-1,n] = 1.5*a_slice[-1] - 2*a_slice[-2] + 0.5*a_slice[-3]
return np.asarray(out).reshape(a.shape)
速度比我想的要慢得多。
在改进代码方面,您可以将步长提高为两倍而不是字节(a_axis_stride_dbl = a_axis_stride/sizeof(double)
),然后索引为pt[i*a_axis_stride_dbl]
。它可能不会获得太大的速度,但是会更具可读性。 (这是您在第1点中提出的问题)