我正在尝试将一些工作现有的numpy / python代码移植到cython
我遇到的一个问题是我不能在cython中对多维数组使用元组索引,而在python / numpy中它确实有用。
这是一个简单的mwe:
cython_indexing.pyx
# cython: boundscheck=False
# cython: wraparound=False
def loop(int axis, double[:, :, :] a, double[:, :, :] b):
cdef:
int k, j, i
tuple q, qp1
for k in range(a.shape[0]):
for j in range(a.shape[1]):
for i in range(a.shape[2]):
q = (k, j, i)
if axis == 0:
qp1 = (k + 1, j, i)
elif axis == 1:
qp1 = (k, j + 1, i)
elif axis == 2:
qp1 = (k, j, i + 1)
# ...
# some other operations
# with heavy reuse of q, qp1
# ...
a[q] = a[q] - (b[qp1] - b[q])
test_indexing.py
import pyximport; pyximport.install()
import numpy as np
from cython_indexing import loop
a = np.arange(27).astype('float').reshape(3, 3, 3)
b = a**2
for axis in (0, 1, 2):
loop(axis, a, b)
此示例在编译时在b[qp1] - b[q]
上抛出错误:
Invalid operand types for '-' (double[:, :]; double[:, :])
是否有任何简单的解决方案 NOT 涉及更改代码架构?
答案 0 :(得分:0)
基本问题是Cython不知道元组在编译时有多大,所以它无法在编译时明智地进行数组索引 - 它不知道有多少维度它必须返回的数组。 (看起来它只是让人感到困惑,但即使它确实有效,它也必须采用一个通用的Python __getitem__
代码路径,因此你不会加快速度。
你可以做出两个(不太难)的改变。首先是当你说
时做我认为你想要避免的事情任何简单的解决方案 NOT 涉及更改代码架构
是使用3个整数而不是元组:
cdef:
int q0, q1, q2, qp1_0, qp1_1, qp1_2
# ....
a[q0,q1,q2] = a[q0,q1,q2] - (b[qp1_0,qp1_1,qp1_2] - b[q0,q1,q2])
第二个不是使用Cython" typed-memoryview"界面,并让a
和b
无类型:
def loop(int axis, a, b):
这将使索引与元组一起工作(如在纯Python中),但不会比纯Python快得多。
不幸的是,这是一个权衡:如果你想要更快的速度,那么你必须避免使用像元组这样的Python对象。