n维数组的通用Cython代码

时间:2018-11-01 18:20:35

标签: python cython n-dimensional

我正在编写很多Cython代码,这些代码需要推广到 n 个维度(大多数情况下,这将是 n = 1、2或3)。

以下面的代码为例,这是元素的简单总和。

cimport cython
from cython cimport floating

cpdef floating nd_sum(floating [:, :, :] arr):
    cdef:
        Py_ssize_t [3] N = [arr.shape[0], arr.shape[1], arr.shape[2]]
        Py_ssize_t [3] i
        floating total = 0

    for i[0] in range(N[0]):
        for i[1] in range(N[1]):
            for i[2] in range(N[2]):
                total += arr[i[0], i[1], i[2]]

    return total

从概念上讲,对 n 维的泛化非常明显。但是,我似乎无法在代码中实现它。我讨厌必须为 n = 1,2,3,...

的每种情况基本上复制此代码

通常,我还需要能够在每个维度上访问每个数组元素的邻居(请考虑 n 维内核卷积等),这使得使数组变平的事情变得不可行。

我的主要问题是:

  • 我如何告诉Cython期望输入任意维数的数组?我想我可能不得不在这里诉诸np.ndarray ...

  • 在给出 n 个索引的数组的情况下,如何从 n 维数组中获取元素?

  • 如何概括 n 嵌套循环的构造?递归在这里似乎是不可避免的,但是如何以这种方式访问​​所有 n 个索引?

0 个答案:

没有答案