可广播的Numpy点

时间:2015-07-01 01:32:32

标签: python algorithm numpy matrix-multiplication blas

我有一个维度为(n0,n2)的数组H和一个维度为(n0,n1,n2,n3)的数组W,我想进行以下操作:

(H[:, None, :, None] * W).sum(axis=(0, 2))

据我所知,上面一行不使用BLAS库。有没有办法使用numpy.dot或类似的函数使用BLAS进行相同的计算(并且仍然没有在内存中多次复制数组H)?

1 个答案:

答案 0 :(得分:1)

你已经找到了一种方法;我知道另外两个人。

一个小例子

In [365]: n0,n1,n2,n3=2,3,4,5
In [366]: H=np.ones((n0,n2));W=np.ones((n0,n1,n2,n3))

比较时间是:

In [362]: timeit np.tensordot(H,W,[(0,1),(0,2)])
10000 loops, best of 3: 32.8 µs per loop

In [363]: timeit np.einsum('ik,ijkl',H,W)
100000 loops, best of 3: 10.7 µs per loop

In [364]: timeit (H[:,None,:,None]*W).sum(axis=(0,2))
10000 loops, best of 3: 29.5 µs per loop

tensordot重新整形并转置输入,以便调用np.doteinsum对字符串进行解码,并在C中执行自己的nditer

https://stackoverflow.com/a/31129207/901925有另一个多维dot的时间,涉及(100,)*(10,100,100)*(100,)数组。