使用Numpy einsum的GEMM

时间:2016-10-11 11:36:46

标签: python numpy blas

单个numpy einsum语句是否可以复制gemm功能?标量和矩阵乘法似乎很简单,但我还没有找到如何使“+”工作。如果它更简单,D = alpha * A * B + beta * C是可接受的(实际上更好)

alpha = 2
beta = 3
A = np.arange(9).reshape(3, 3)
B = A + 1
C = B + 1

left_part = alpha*np.dot(A, B)
print(left_part)
left_part = np.einsum(',ij,jk->ik', alpha, A, B)
print(left_part)

1 个答案:

答案 0 :(得分:1)

这里似乎有些混淆:np.einsum处理可以按以下形式强制转换的操作:broadcast-multiply-reduce。元素总和不是其范围的一部分。

为什么你需要这种乘法的原因是将这些操作写出来"天真地"可能会快速超过内存或计算资源。例如,考虑矩阵乘法:

import numpy as np
x, y = np.ones((2, 2000, 2000))

# explicit loop - ridiculously slow
a = sum(x[:,j,np.newaxis] * y[j,:] for j in range(2000))

# explicit broadcast-multiply-reduce: throws MemoryError
a = (x[:,:,np.newaxis] * y[:,np.newaxis,:]).sum(1)

# einsum or dot: fast and memory-saving
a = np.einsum('ij,jk->ik', x, y)

然而,爱因斯坦的惯例分解为加法,所以你 可以简单地将你的BLAS问题写成:

d = np.einsum(',ij,jk->ik', alpha, a, b) + np.einsum(',ik', beta, c)

具有最小的内存开销(如果您真的关心内存,可以将大部分内容重写为就地操作)和持续的运行时开销(两次python-to-C调用的成本)。

所以关于性能,这看起来像是一个对我来说过早优化的情况:你真的已经证实将类似GEMM的操作分成两个单独的numpy调用是你代码中的瓶颈吗?如果确实如此,那么我建议如下(按照增加的参与度的顺序):

  1. 仔细试试!scipy.linalg.blas.dgemm。如果你得到,我会感到惊讶   性能明显提高,因为dgemm通常只是。{   建筑自己。

  2. 尝试表达式编译器(基本上你是在提议   像Theano这样的事情。

  3. 使用Cython或C编写您自己的generalised ufunc