Python numpy matrix multiplication with a diagonal matrix

Time: 2017-06-06 11:06:49

Tags: numpy matrix-multiplication

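I have two arrays: A with shape (4000, 4000), in which only the diagonal holds data, and B with shape (4000, 5), which is fully populated. Is there a way to multiply (dot) these arrays that is faster than numpy.dot(a, b)?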

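So far I have found that (A * B.T).T should be faster (where A is one-dimensional, with shape (4000,), and holds the diagonal elements), but in my tests it turned out to be roughly twice as slow.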

Is there a faster way to compute B.dot(A) when A is a diagonal array?

2 answers:

Answer 0: (score: 3)

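You can simply extract the diagonal elements and then perform a broadcast element-wise multiplication. Note that the names below are swapped relative to the question: here B is the (M, M) diagonal matrix and A is the dense (M, N) array, and both are np.matrix objects, so * means matrix multiplication.

Thus, a replacement for B*A would be: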

np.multiply(np.diag(B)[:,None], A)

And for A.T*B:

np.multiply(A.T,np.diag(B))
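
To see why the replacement works: multiplying by a diagonal matrix on the left scales each row of the other operand by the matching diagonal entry, and multiplying on the right scales each column. The sketch below checks this on a tiny case with plain ndarrays, where @ is the matrix product and * is element-wise. The names d, D and X are illustrative and do not come from the answer.

```python
import numpy as np

d = np.array([2.0, 3.0, 4.0])        # diagonal entries as a 1-D array
D = np.diag(d)                       # full 3 x 3 diagonal matrix
X = np.arange(6.0).reshape(3, 2)     # dense 3 x 2 array

# Left-multiplying by D scales row i of X by d[i].
left = d[:, None] * X
assert np.allclose(left, D @ X)

# Right-multiplying X.T by D scales column j of X.T by d[j].
right = X.T * d
assert np.allclose(right, X.T @ D)

print(left)
```

The broadcast version touches only the M diagonal entries instead of all M*M entries of the full matrix, which is where the speedup comes from.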

Runtime test:

In [273]: # Setup
     ...: M,N = 4000,5
     ...: A = np.random.randint(0,9,(M,N)).astype(float)
     ...: B = np.zeros((M,M),dtype=float)
     ...: np.fill_diagonal(B, np.random.randint(11,99,(M)))
     ...: A = np.matrix(A)
     ...: B = np.matrix(B)
     ...: 

In [274]: np.allclose(B*A, np.multiply(np.diag(B)[:,None], A))
Out[274]: True

In [275]: %timeit B*A
10 loops, best of 3: 32.1 ms per loop

In [276]: %timeit np.multiply(np.diag(B)[:,None], A)
10000 loops, best of 3: 33 µs per loop

In [282]: np.allclose(A.T*B, np.multiply(A.T,np.diag(B)))
Out[282]: True

In [283]: %timeit A.T*B
10 loops, best of 3: 24.1 ms per loop

In [284]: %timeit np.multiply(A.T,np.diag(B))
10000 loops, best of 3: 36.2 µs per loop

Answer 1: (score: 0)

It seems my original claim that (A * B.T).T is slower was incorrect.

from timeit import default_timer as timer
import numpy as np

##### Case 1
a = np.zeros((4000,4000))
np.fill_diagonal(a, 10)
b = np.ones((4000,5))

dot_list = []

def time_dot(a,b):
    start = timer()
    c = np.dot(a,b)
    end = timer()
    return end - start

for i in range(100):
    dot_list.append(time_dot(a,b))

print(np.mean(np.asarray(dot_list)))

##### Case 2
a = np.ones((4000,))
a = a * 10
b = np.ones((4000,5))

shortcut_list = []

def time_quicker(a,b):
    start = timer()
    c = (a*b.T).T
    end = timer()
    return end - start

for i in range(100):
    shortcut_list.append(time_quicker(a,b))

print(np.mean(np.asarray(shortcut_list)))


##### Case 3
a = np.zeros((4000,4000)) #diagonal matrix
np.fill_diagonal(a, 10)
b = np.ones((4000,5))

case3_list = []

def function(a,b):
    start = timer()
    np.multiply(b.T,np.diag(a))
    end = timer()
    return end - start

for i in range(100):
    case3_list.append(function(a,b))

print(np.mean(np.asarray(case3_list)))

Results (mean time per call in seconds, for Cases 1, 2 and 3 in order):

0.119120892431

0.00010633951868

0.00214490709662

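So the second method, (a*b.T).T with the diagonal stored as a 1-D array, is the fastest: in this run it is roughly 1,100 times faster than np.dot and about 20 times faster than the np.diag variant.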
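
The comparison above can also be sketched with the standard timeit module, which repeats each call and makes the timings less sensitive to one-off noise. This is only a sketch under stated assumptions: the size n = 1000 is smaller than the 4000 used above to keep the run short, the function names are made up for illustration, and column_broadcast is an extra variant that does not appear in the answer. It also checks that all variants return the same product before timing them.

```python
import timeit
import numpy as np

n, k = 1000, 5                        # assumed sizes, smaller than 4000 x 5
rng = np.random.default_rng(0)
d = rng.uniform(1.0, 10.0, size=n)    # diagonal entries as a 1-D array
a_full = np.diag(d)                   # full (n, n) diagonal matrix
b = rng.uniform(size=(n, k))

def full_dot():
    return np.dot(a_full, b)          # Case 1: dense matrix product

def transpose_trick():
    return (d * b.T).T                # Case 2: broadcast along the last axis

def column_broadcast():
    return d[:, None] * b             # same product without the transposes

# All variants must agree before their timings mean anything.
expected = full_dot()
assert np.allclose(transpose_trick(), expected)
assert np.allclose(column_broadcast(), expected)

for fn in (full_dot, transpose_trick, column_broadcast):
    # Best of 5 repeats, 20 calls each, reported per call.
    best = min(timeit.repeat(fn, number=20, repeat=5)) / 20
    print(f"{fn.__name__}: {best:.2e} s per call")
```

Actual numbers depend on the machine, the BLAS build and the array sizes, so the ranking should be re-measured rather than assumed.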