n维数组上的cython元组索引

时间:2017-02-11 13:28:51

标签: python indexing cython

我正在尝试将一些工作现有的numpy / python代码移植到cython

我遇到的一个问题是我不能在cython中对多维数组使用元组索引,而在python / numpy中它确实有用。

这是一个简单的mwe:

cython_indexing.pyx

# cython: boundscheck=False
# cython: wraparound=False
def loop(int axis, double[:, :, :] a, double[:, :, :] b):
    cdef:
        int k, j, i
        tuple q, qp1

    for k in range(a.shape[0]):
        for j in range(a.shape[1]):
            for i in range(a.shape[2]):
                q = (k, j, i)
                if axis == 0:
                    qp1 = (k + 1, j, i)
                elif axis == 1:
                    qp1 = (k, j + 1, i)
                elif axis == 2:
                    qp1 = (k, j, i + 1)

                # ...
                # some other operations
                # with heavy reuse of q, qp1
                # ...

                a[q] = a[q] - (b[qp1] - b[q])

test_indexing.py

import pyximport; pyximport.install()
import numpy as np
from cython_indexing import loop

a = np.arange(27).astype('float').reshape(3, 3, 3)
b = a**2

for axis in (0, 1, 2):
    loop(axis, a, b)

此示例在编译时在b[qp1] - b[q]上抛出错误:

Invalid operand types for '-' (double[:, :]; double[:, :])

是否有任何简单的解决方案 NOT 涉及更改代码架构?

1 个答案:

答案 0 :(得分:0)

基本问题是Cython不知道元组在编译时有多大,所以它无法在编译时明智地进行数组索引 - 它不知道有多少维度它必须返回的数组。 (看起来它只是让人感到困惑,但即使它确实有效,它也必须采用一个通用的Python __getitem__代码路径,因此你不会加快速度。

你可以做出两个(不太难)的改变。首先是当你说

时做我认为你想要避免的事情
  

任何简单的解决方案 NOT 涉及更改代码架构

是使用3个整数而不是元组:

cdef:
   int q0, q1, q2, qp1_0, qp1_1, qp1_2

# ....
a[q0,q1,q2] = a[q0,q1,q2] - (b[qp1_0,qp1_1,qp1_2] - b[q0,q1,q2])

第二个不是使用Cython" typed-memoryview"界面,并让ab无类型:

def loop(int axis, a, b):

这将使索引与元组一起工作(如在纯Python中),但不会比纯Python快得多。

不幸的是,这是一个权衡:如果你想要更快的速度,那么你必须避免使用像元组这样的Python对象。