是否有“增强的”numpy / scipy dot方法?

时间:2012-02-28 08:44:32

标签: python math numpy scipy

问题

我想使用numpy或scipy来计算以下内容:

Y = A**T * Q * A

其中Am x n矩阵,A**TA的转置,Qm x m对角矩阵。

由于Q是对角矩阵,因此我只将其对角线元素存储为矢量。

解决Y

的方法

目前我可以想到两种计算Y的方法:

  1. Y = np.dot(np.dot(A.T, np.diag(Q)), A)
  2. Y = np.dot(A.T * Q, A)
  3. 显然,选项2优于选项1,因为不需要使用diag(Q)创建真实矩阵(如果这是numpy真正做的......)
    但是,由于A.T * Qnp.dot(A.T, np.diag(Q))必须与A一起存储才能计算Y,因此这两种方法都存在必须分配比实际需要更多内存的缺陷。 }。

    问题

    是否存在numpy / scipy中的方法可以消除不必要的额外内存分配,而只传递两个矩阵AB(在我的情况下B是{{} 1}})和加权向量A.T以及它?

3 个答案:

答案 0 :(得分:25)

(w / r / t OP的最后一句话:我意识到这种numpy / scipy方法但是没有/或者是OP标题中的问题(即改进NumPy点性能)下面的内容应该有所帮助。换句话说,我的答案是针对提高大多数步骤的性能,包括你的Y功能。

首先,这应该会对你的香草NumPy dot 方法有明显的推动作用:

>>> from scipy.linalg import blas as FB
>>> vx = FB.dgemm(alpha=1., a=v1, b=v2, trans_b=True)

请注意,两个数组v1,v2在

您可以通过数组的 标志 属性访问NumPy数组的字节顺序,如下所示:

>>> c = NP.ones((4, 3))
>>> c.flags
      C_CONTIGUOUS : True          # refers to C-contiguous order
      F_CONTIGUOUS : False         # fortran-contiguous
      OWNDATA : True
      MASKNA : False
      OWNMASKNA : False
      WRITEABLE : True
      ALIGNED : True
      UPDATEIFCOPY : False

更改其中一个数组的顺序,使两者都对齐,只需调用NumPy数组构造函数,传入数组并将相应的 order 标志设置为True

>>> c = NP.array(c, order="F")

>>> c.flags
      C_CONTIGUOUS : False
      F_CONTIGUOUS : True
      OWNDATA : True
      MASKNA : False
      OWNMASKNA : False
      WRITEABLE : True
      ALIGNED : True
      UPDATEIFCOPY : False

您可以通过利用数组顺序对齐来进一步优化减少由复制原始数组引起的过多内存消耗。

但为什么在传递给 dot 之前复制数组?

点积依赖于BLAS操作。这些操作需要以C连续顺序存储的数组 - 这是导致数组被复制的约束。

另一方面,转置 影响副本,但不幸的是在 Fortran命令中返回结果:

因此,要消除性能瓶颈,您需要消除谓词数组复制步骤;要做到这一点,只需要将两个数组以C连续顺序传递给 dot

所以计算 dot(A.T。,A) ,不用制作额外副本:

>>> import scipy.linalg.blas as FB
>>> vx = FB.dgemm(alpha=1.0, a=A.T, b=A.T, trans_b=True)

总而言之,上面的表达式(以及谓词import语句)可以替代dot,以提供相同的功能但性能更好

你可以将该表达式绑定到这样的函数:

>>> super_dot = lambda v, w: FB.dgemm(alpha=1., a=v.T, b=w.T, trans_b=True)

答案 1 :(得分:4)

我只是想把它放在SO上,但是这个拉取请求应该是有帮助的,并且不需要为numpy.dot提供单独的函数。 https://github.com/numpy/numpy/pull/2730 这应该是numpy 1.7

与此同时,我使用上面的例子来编写一个可以替换numpy dot的函数,无论数组的顺序是什么,并正确调用fblas.dgemm。 http://pastebin.com/M8TfbURi

希望这有帮助,

答案 2 :(得分:0)

numpy.einsum 是您正在寻找的:

numpy.einsum('ij, i, ik -> jk', A, Q, A)

这不需要任何额外的内存(尽管通常einsum比BLAS操作更慢)