如何在Cython中静态输入gmpy2.mpq列表?

时间:2016-03-25 11:49:14

标签: cython

我实现了一个基于gmpy2.mpq的高斯估计函数。分析后,它被证明是我的程序的瓶颈。我试过Cython来优化它。

Cython太棒了,速度已经增加了两倍左右,但是我试图通过静态输入来加快速度。

当我这样做时,有两个问题:

  1. gmpy2.mpq是一个函数而不是一个类型。如何静态输入呢?
  2. 如何静态输入给定类型的列表?
  3. 如果列表有更快的替代方案,那将会很棒。

    我已经附上了代码,仅供参考。

    def gauss_estimate(a, b):
        """Gauss estimation for integers
    
        :param a:a n*m 2d sequence of integers, where n>=m
        :param b:an n-length 1d sequence of integers
        :returns: a m-length 1d sequence of mpq that a@x=b, gmpy2.mpq is a fast
            implementing of fraction
        :raises: a SingularError if the rank of a is smaller than m
        :raises: a NoSolutionError if the rank of a is larger than m
        """
        cdef int n, m, i, j
        cdef list aa, bb
        if isinstance(a, ndarray):
            a = a.tolist()
        if isinstance(b, ndarray):
            b = b.tolist()
        aa = [[mpq(aii) for aii in ai] for ai in a]
        n = len(aa)
        m = len(aa[0])
        if n < m:
            raise ValueError('Wrong shape of a')
        for ai in aa:
            if len(ai) != m:
                raise ValueError('Wrong shape of ai')
        bb = [mpq(bi) for bi in b]
        if len(bb) != n:
            raise ValueError('Wrong shape of b')
        for i in range(m):
            if aa[i][i] == 0:
                for j in range(i, n):
                    if aa[j][i] != 0:
                        aa[i], aa[j] = aa[j], aa[i]
                        bb[i], bb[j] = bb[j], bb[i]
                        break
                else:
                    raise SingularError('The rank of a is smaller than m')
            bb[i] /= aa[i][i]
            for j in reversed(range(i, m)):
                aa[i][j] /= aa[i][i]
            for j in range(i+1, n):
                bb[j] -= aa[j][i] * bb[i]
                for k in reversed(range(i, m)):
                    aa[j][k] -= aa[j][i] * aa[i][k]
                assert aa[j][i] == 0
        for i in range(m, n):
            if bb[i] != 0:
                raise NoSolutionError('No solution found')
        for i in reversed(range(m)):
            for j in range(0, i):
                bb[j] -= bb[i] * aa[j][i]
        for i in range(m):
            assert aa[i][i] == 1
        return tuple(bb[:m])
    

0 个答案:

没有答案