如何忽略矩阵乘法中的零?

时间:2019-05-24 15:21:29

标签: python numpy math

假设我有一个10000 x 10000的矩阵W,它具有随机数,以及两个10000暗矢量U和V,U中具有随机数,V填充有零。 使用numpy或pytorch,计算U @ W和V @ W花费的时间相同。我的问题是,有没有一种方法可以优化矩阵乘法,以便在计算过程中跳过或忽略零,因此V @ W之类的东西将被更快地计算?

import numpy as np
W = np.random.rand(10000, 10000)

U = np.random.rand(10000)
V = np.zeros(10000)

y1 = U @ W
y2 = V @ W
# computing y2 should take less amount of time than y1 since it always returns zero vector.

2 个答案:

答案 0 :(得分:3)

您可以使用scipy.sparse类来提高性能,但这完全取决于矩阵。例如,使用V作为稀疏矩阵所获得的性能将会很好。通过将U转换为稀疏矩阵而获得的结果将不是很大,或者实际上可能会降低性能(在U实际上是密集的情况下)。

import numpy as np
import scipy.sparse as sps

W = np.random.rand(10000, 10000)
U = np.random.rand(10000)
V = np.zeros(10000)

%timeit U @ W
125 ms ± 1.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit V @ W
128 ms ± 6.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Vsp = sps.csr_matrix(V)
Usp = sps.csr_matrix(U)
Wsp = sps.csr_matrix(W)

%timeit Vsp.dot(Wsp)
1.34 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 
%timeit Vsp @ Wsp
1.39 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit Usp @ Wsp
2.37 s ± 84.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

如您所见,对V @ W使用稀疏方法有很大的改进,但是实际上U @ W的性能降低了,因为U或W中的任何一项都不为零。

答案 1 :(得分:0)

In [274]: W = np.random.rand(10000, 10000) 
     ...:  
     ...: U = np.random.rand(10000) 
     ...: V = np.zeros(10000)                                                                            
In [275]: timeit U@W                                                                                     
125 ms ± 263 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [276]: timeit V@W                                                                                     
153 ms ± 18.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

现在考虑V的100个元素为非零(1s)的情况。稀疏的实现可能是:

In [277]: Vdata=np.ones(100); Vind=np.arange(0,10000,100)                                                
In [278]: Vind.shape                                                                                     
Out[278]: (100,)
In [279]: timeit Vdata@W[Vind,:]                                                                         
4.99 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这一次让我有些惊讶,认为W的索引可以抵消乘法时间。

让我们更改V来验证结果:

In [280]: V[Vind]=1                                                                                      
In [281]: np.allclose(V@W, Vdata@W[Vind,:])  

如果我必须先找到非零元素怎么办?

In [282]: np.allclose(np.where(V),Vind)                                                                  
Out[282]: True
In [283]: timeit idx=np.where(V); V[idx]@W[idx,:]                                                        
5.07 ms ± 77.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

W的大小(尤其是第二维)可能是此加速的重要因素。在这些大小下,内存管理与原始乘法一样会影响速度。

===

在这种情况下,sparse比我预期的要好(其他测试建议我需要大约1%的稀疏度才能获得时间优势)

In [294]: from scipy import sparse                                                                       
In [295]: Vc=sparse.csr_matrix(V)                                                                        
In [296]: Vc.dot(W)                                                                                      
Out[296]: 
array([[46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
        45.54413903, 48.28613399]])
In [297]: V.dot(W)                                                                                       
Out[297]: 
array([46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
       45.54413903, 48.28613399])
In [298]: np.allclose(Vc.dot(W),V@W)                                                                     
Out[298]: True

In [299]: timeit Vc.dot(W)                                                                               
1.48 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

即使是稀疏的创建:

In [300]: timeit Vm=sparse.csr_matrix(V); Vm.dot(W)                                                      
2.01 ms ± 7.89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)