我有许多矩阵对要相乘,以下是一个示例代码
D = np.random.rand(100, 10)
B = np.random.rand(10, 100)
split = array([ 3, 6, 11, 14, 18, 25, 31, 38, 45, 52, 60, 67, 84, 88, 90, 95])
DD = np.vsplit(D, split)
BB = np.hsplit(B, split)
G = [ m0@m1 for m0, m1 in zip(DD, BB)]
以下是我电脑上的测试:
In [42]: %timeit [m0@m1 for m0, m1 in zip(DD, BB)]
The slowest run took 10.01 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 24.5 µs per loop
沿着轴0,D
被分成许多具有不同数字行的小数组。沿着轴1,B
被分成许多小数组,这些数组具有不同的数字列。在这里,我使用列表理解来完成我的工作。
列表理解是否是更快的方式来完成这项工作?或者有没有更快的方法在numpy中做这样的事情?可以直接在D
和B
吗?
答案 0 :(得分:1)
如果分割的大小不相等,那么将其映射到单个广播操作的唯一方法是在适当构造的块对角矩阵上进行稀疏矩阵乘法。 例如,您可以这样处理它:
from scipy.sparse import block_diag
def split_dot_sparse(D, B, split):
# create block-diagonal matrices
DD = block_diag(np.vsplit(D, split))
BB = block_diag(np.hsplit(B, split))
# multiply the blocks
DDBB = DD @ BB
# extract the results
return [DDBB[i:j, i:j].toarray() for i, j in zip([0, *split], [*split, D.shape[0] + 1])]
这会产生与列表理解相同的结果:
def split_dot_list_comp(D, B, split):
DD = np.vsplit(D, split)
BB = np.hsplit(B, split)
return [m0@m1 for m0, m1 in zip(DD, BB)]
G1 = split_dot_list_comp(D, B, split)
G2 = split_dot_sparse(D, B, split)
all(np.allclose(*mats) for mats in zip(G1, G2)
# True
不幸的是,稀疏方法比简单列表理解方法慢得多:
%timeit split_dot_list_comp(D, B, split)
# 73.5 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit split_dot_sparse(D, B, split)
# 4.67 ms ± 48 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
我们可能能够优化块矩阵的创建或结果的提取(如果你非常小心,甚至可以从数据的无副本视图创建稀疏矩阵),但即便如此矩阵乘积本身比列表理解慢几个因素。 对于不均匀的分裂,您将无法比基线方法做得更好。
如果你甚至分裂,故事会有所不同,因为那时你可以使用numpy广播来快速计算结果。它可能看起来像这样:
def split_dot_broadcasted(D, B, splitsize):
DD = D.reshape(-1, splitsize, D.shape[1])
BB = B.reshape(B.shape[0], -1, splitsize)
return DD @ BB.transpose(1, 0, 2)
这给出了与列表理解方法相同的结果:
splitsize = 5
split = splitsize * np.arange(1, D.shape[0] // splitsize)
G1 = split_dot_list_comp(D, B, split)
G2 = split_dot_broadcasted(D, B, splitsize)
np.allclose(G1, G2)
# True
广播方法的速度提高了几倍:
%timeit split_dot_list_comp(D, B, split)
# 83.6 µs ± 314 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit split_dot_broadcasted(D, B, splitsize)
# 29.3 µs ± 539 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
所以,长话短说:如果分裂不均匀,你可能无法击败你在问题中提出的列表理解。对于偶数大小的分割,使用numpy的广播将是一些更快的因素。