用Python(和Cython)计算两个矩阵的点积的最快方法是什么

时间:2019-08-11 19:02:21

标签: python cython matrix-multiplication

我试图用Cython创建一个Python库,我需要在其中实现点积。我有一种非常简单的计算点积的方法,但是,对于较大的矩阵,它的运行速度不够快。

我花了很多时间来研究这个问题,并试图使其尽快解决,但是我无法使其更快地解决问题。

以下代码显示了我当前计算方式的Python实现:

a = [[1, 2, 3], [4, 5, 6]]
b = [[1], [2], [3]]

def dot(a, b):
    c = [[0 for j in range(len(b[i]))] for i in range(len(a))]

    for i in range(len(c)):
        for j in range(len(c[i])):
            t = 0
            for k in range(len(b)):
                t += a[i][k] * b[k][j]
            c[i][j] = t
    return c

print(dot(a, b))
# [[14], [32]]

这确实给出了正确的计算结果(python [[14], [32]]),但是对于我将要使用的计算时间却太长了。我将如何更快地提供任何帮助,将不胜感激。谢谢

3 个答案:

答案 0 :(得分:4)

您可以为此使用numpy。 Numpy实现了BLAS规范(基本线性代数子程序),它们是线性代数库的低级例程(例如矩阵乘法)的事实上的标准。要获得两个矩阵的点积,例如AB,可以使用以下代码:

A = [[1, 2, 3], [4, 5, 6]]
B = [[1], [2], [3]]

import numpy as np #Import numpy

numpy_a = np.array(A) #Cast your nested lists to numpy arrays
numpy_b = np.array(B)
print(np.dot(numpy_a, numpy_b)) #Print the result

答案 1 :(得分:2)

根据结构的索引成本,您可以通过排除一些操作来提高速度:

def dot(a, b):
    c = [[0 for j in range(len(b[i]))] for i in range(len(a))]
    bt = transpose(b)        # can this be done once cheaply?
    for i in range(len(c)):
        a1 = a[i]
        c1 = c[i]
        for j in range(len(c1)):
            b1 = bt[j]
            t = 0
            for k in range(len(b)):
                t += a1[k] * b1[k]
            c1[j] = t
    return c

内部k循环可以用惯用的Python编写为:

for a2, b2 in zip(a1, b1):
     t += a2 * b2

我不知道cython翻译的速度是否更快。

快速cython还需要将各种变量定义为intfloat等,以便可以进行直接c的翻译,而不是通过通用但昂贵的Python对象进行翻译。我不会尝试重复cython文档。

答案 2 :(得分:2)

您应该注释(,静态键入)所有可能的变量。如果您愿意,以下是我的解决方案:

# mydot.pyx
import numpy as np
cimport cython

def dot_1(a, b):
    c = [[0 for j in range(len(b[i]))] for i in range(len(a))]

    for i in range(len(c)):
        for j in range(len(c[i])):
            t = 0
            for k in range(len(b)):
                t += a[i][k] * b[k][j]
            c[i][j] = t
    return c


@cython.boundscheck(False)  # turn off bounds-checking
@cython.wraparound(False)  # turn off negative index wrapping
def dot_2(double[:, :] A, double[:, :] B):
    cdef Py_ssize_t M = A.shape[0]
    cdef Py_ssize_t Na = A.shape[1]
    cdef Py_ssize_t Nb = B.shape[0]
    cdef Py_ssize_t K = B.shape[1]

    assert Na == Nb

    result = np.empty((M, K), dtype='d')
    cdef double[:, :] C = result

    cdef double t

    for m in range(M):
        for k in range(K):
            t = 0
            for n in range(Na):
                t += A[m, n] * B[n, k]
            C[m, k] = t

    return result

# app.py
import pyximport
from numpy import array
from scipy import median
from timeit import repeat

pyximport.install()
from mydot import dot_1, dot_2


a = array([[1, 2, 3], [4, 5, 6]], dtype='d')
b = array([[1], [2], [3]], dtype='d')

dot_1_t = repeat('dot_1(a, b)', repeat=1000, number=1, globals=globals())
dot_2_t = repeat('dot_2(a, b)', repeat=1000, number=1, globals=globals())

print(f'dot_1 took {median(dot_1_t)*1000} ms.')
print(f'dot_2 took {median(dot_2_t)*1000} ms.')

运行cython --annotate mydot.pyx时,Cython将生成一个注释Cython代码的HTML文件。在那里,黄色高亮显示的颜色越深,生成的C代码的开销就越多(Python)。您可以将两种解决方案(尤其是for循环)进行相互比较。

运行python app.py还可以为您带来更快的结果。当然,如果您提供的小尺寸输入低于某个阈值,那么您将不会看到两者之间有意义的速度差异,因为您没有进行足够的迭代。但是,在达到某个阈值之后,速度差异应该会很大,因为循环中的每次迭代对于您的版本而言都是昂贵的(请参见深黄线)。

最后一句话是,正如该问题下的每个人都已经建议过的那样,当您提供较大尺寸的矩阵时,numpy的函数应该更高效--它们正在使用阻塞的(子)矩阵操作从基础的BLAS和LAPACK实现中提取数据,而不是天真的对索引进行逐一迭代。

PS:如果您不仅想在dot_2上专门研究double,还想在其他有意义的民用类型(例如intfloat)上进行专业研究,则应检查Cython的fused types

编辑。由于后来我的回答被选为答案,因此我想举一个较大的输入示例。如果不是使用上面的app.py,而是使用以下内容:

# app.py
import pyximport
from numpy import array, random as rnd
from scipy import median
from timeit import repeat

pyximport.install()
from mydot import dot_1, dot_2


M = 100
N = 100
K = 1

a = rnd.randn(M, N)
b = rnd.randn(N, K)

dot_1_t = repeat('dot_1(a, b)', repeat=1000, number=1, globals=globals())
dot_2_t = repeat('dot_2(a, b)', repeat=1000, number=1, globals=globals())

print(f'dot_1 took {median(dot_1_t)*1000} ms.')
print(f'dot_2 took {median(dot_2_t)*1000} ms.')

时间安排应类似于以下内容:

dot_1 took 5.218300502747297 ms.
dot_2 took 0.013017997844144702 ms.