如何用面具加速numpy dot产品?

时间:2016-05-31 22:23:09

标签: arrays performance numpy matrix-multiplication

我有2个numpy数组,m1m2其中m1的大小(nx1)和m2是大小(1xn),我想执行乘法{ {1}}生成大小为m1.dot(m2)的矩阵(nxn)

我想通过仅使用mm_approx中的最高k个元素并使所有其他元素为0(所有元素都为正)来计算近似m1

我正在尝试加速乘法,因为我的大小m2很大(~10k)。我想选一个小的n说100并真正加速乘法。我尝试使用numpy稀疏矩阵,它确实使得点积更快,但是将m1和m2转换为稀疏矢量非常慢。我怎样才能做到这一点?我觉得面具可能是实现这一目标的一种方式,但不确定如何?

1 个答案:

答案 0 :(得分:3)

可以使用np.argpartition来获取最大k元素和np.ix_的索引,以便从m1和{{选择和设置所选元素的点积来解决1}}。因此,我们基本上有两个阶段来实现这一点,如下所述。

首先,获取与m2k中最大的m1元素对应的索引,如此 -

m2

最后,设置输出数组。使用m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel() m2_idx = np.argpartition(-m2,k)[:,:k].ravel() 分别沿行和列广播np.ix_m1索引,以选择要设置的输出数组中的元素。接下来,计算来自m2k的最高m1元素之间的点积,可以使用{m2m1获取索引。 1}}和m2,就像这样 -

m1_idx

让我们通过针对另一个实现运行示例运行验证实现,该实现在m2_idx out = np.zeros((n,n)) out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx]) n-k 0 m1中明确设置较低的m2元素然后执行点积。这是执行检查的示例运行 -

1)输入:

In [170]: m1
Out[170]: 
array([[ 0.26980423],
       [ 0.30698416],
       [ 0.60391089],
       [ 0.73246763],
       [ 0.35276247]])

In [171]: m2
Out[171]: array([[ 0.30523552, 0.87411242, 0.01071218, 0.81835438, 0.21693231]])

In [172]: k = 2

2)运行建议的实施:

In [173]: # Proposed solution code
     ...: m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()
     ...: m2_idx = np.argpartition(-m2,k)[:,:k].ravel()
     ...: out = np.zeros((n,n))
     ...: out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])
     ...: 

3)使用替代实现来获得输出:

In [174]: # Explicit setting of lower n-k elements to zeros for m1 and m2
     ...: m1[np.argpartition(-m1,k,axis=0)[k:]] = 0
     ...: m2[:,np.argpartition(-m2,k)[:,k:].ravel()] = 0
     ...: 

In [175]: m1  # Verify m1 and m2 have lower n-k elements set to 0s
Out[175]: 
array([[ 0.        ],
       [ 0.        ],
       [ 0.60391089],
       [ 0.73246763],
       [ 0.        ]])

In [176]: m2
Out[176]: array([[ 0.       , 0.87411242, 0.        , 0.81835438, 0.        ]])

In [177]: m1.dot(m2)  # Use m1.dot(m2) to directly get output. This is expensive.
Out[177]: 
array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.52788601,  0.        ,  0.49421312,  0.        ],
       [ 0.        ,  0.64025905,  0.        ,  0.59941809,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])

4)验证我们的建议实施:

In [178]: out   # Print output from proposed solution obtained earlier
Out[178]: 
array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.52788601,  0.        ,  0.49421312,  0.        ],
       [ 0.        ,  0.64025905,  0.        ,  0.59941809,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])