将带有布尔索引的python函数转换为cython

时间:2019-03-06 17:26:42

标签: python cython

下面,我要用Cythonize用纯Python编写一个函数。

def do_stuff(M_i, C_i):
    return M_i.dot(C_i).dot(M_i)

def foo(M, C):
    '''
    M : np.array
        N x J matrix
    C : np.array
        J x J matrix
    '''

    N = M.shape[0]

    tot = 0

    for i in range(N):
        nonmiss = ~np.isnan(M[i, :])
        M_i = M[i, nonmiss] # select non empty components of M
        C_i = C[nonmiss, :] # select corresponding components of C
        C_i = C_i[:, nonmiss] # select corresponding components of C

        tot = tot + do_stuff(M_i, C_i)

    return tot

假设我知道如何对函数do_stuff进行Cythonize。我感兴趣的实际do_stuff函数比上面的函数复杂,但我想提供一个示例。实数do_stuff函数除了矩阵乘法之外,还计算行列式并取逆。

我的主要问题与创建M_iC_i子矢量和子矩阵有关。我不确定在Cython中是否可以执行相同的布尔索引。如果可以的话,我不知道怎么做。但是我可以从我知道的Cython开始。

def foo_c(double[:, ::1] M, double[:, ::1] C):

    cdef int N = M.shape[0]
    cdef double tot = 0
    ...

    for i in range(N):
        ...
        tot = tot + do_stuff_c(M_i, C_i)

    return tot

2 个答案:

答案 0 :(得分:0)

在这里您可能不会获得太大的速度,因为无论如何,Numpy中的布尔索引是在C中实现的,因此应该相当快。您要避免的主要事情是创建一些不必要的中间体(这涉及内存分配,因此可能很慢)。

您要为M_iC_i创建临时数组,它们的大小可能最大(即JJxJ)。遍历isnan(M_I)时,您会跟踪实际存储了多少个值。然后最后将M_iC_i修剪为仅使用的部分:

未经测试的代码:

for i in range(N):
    filled_count_j = 0
    M_i = np.empty((M.shape[1],))
    C_i = np.empty((M.shape[1],M.shape[1]))
    for j in range(M.shape[1]):
        if ~isnan(M[i,j]):
            M_i[filled_count] = M[i,j]

            filled_count_k = 0
            for k in range(M.shape[1]):
                if ~isnan(M[i,k]):
                    C_i[filled_count_j,filled_count_k] = C[j,k]
                    filled_count_k += 1
            filled_count_j += 1
     M_i = M_i[:filled_count_j]
     C_i = C_i[:filled_count_j,:filled_count_j]

答案 1 :(得分:0)

如果您准备好使用纯数组,您可以通过对所有代码进行 cythonizing 来提高一些速度(可能是 1.2x-5x)。但这可能不值得您花费时间,失去 numpy.conf 的便利性。这取决于您的项目中的优先级是什么以及您的 do_stuff() 函数有多大。