覆盖特殊矩阵的matmul

时间:2019-05-08 17:16:18

标签: python numpy

我有一个应用程序,在该应用程序中,我必须进行很多(成本很高)矩阵乘法。这些矩阵中有许多(但不是全部)具有特殊的结构,对于我来说,很容易定义自定义数据表示形式和自定义矩阵乘法例程。我希望能够将它们与表示为普通2d ndarrays的任意矩阵混合并匹配。

举一个最简单的例子,我想做这样的事情:

import numpy as np

class IdentityMatrix:
    def __matmul__(self, other):
        return other

    def __rmatmul__(self, other):
        return other

    __array_priority__ = 10000

其中IdentityMatrix知道,当作用于左侧或右侧的矩阵时,它什么也不做。我希望将__array_priority__设置为较高的数字会导致当另一个参数是ndarray时,它总是会覆盖ndarray的矩阵乘法,但这是行不通的:

In [2]: A = IdentityMatrix
   ...: B = np.array([[3,4],[6,1]])
   ...: A @ B
   ...: 
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-8f0beccd3079> in <module>()
      1 A = IdentityMatrix
      2 B = np.array([[3,4],[6,1]])
----> 3 A @ B

ValueError: matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)

反转输入会产生相同的错误。令人沮丧的是,这样做确实seem to work for operations other than matmul

有什么简单的方法可以覆盖我想要的matmul吗?

谢谢!

0 个答案:

没有答案