我有一个应用程序,在该应用程序中,我必须进行很多(成本很高)矩阵乘法。这些矩阵中有许多(但不是全部)具有特殊的结构,对于我来说,很容易定义自定义数据表示形式和自定义矩阵乘法例程。我希望能够将它们与表示为普通2d ndarrays的任意矩阵混合并匹配。
举一个最简单的例子,我想做这样的事情:
import numpy as np
class IdentityMatrix:
def __matmul__(self, other):
return other
def __rmatmul__(self, other):
return other
__array_priority__ = 10000
其中IdentityMatrix
知道,当作用于左侧或右侧的矩阵时,它什么也不做。我希望将__array_priority__
设置为较高的数字会导致当另一个参数是ndarray时,它总是会覆盖ndarray的矩阵乘法,但这是行不通的:
In [2]: A = IdentityMatrix
...: B = np.array([[3,4],[6,1]])
...: A @ B
...:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-2-8f0beccd3079> in <module>()
1 A = IdentityMatrix
2 B = np.array([[3,4],[6,1]])
----> 3 A @ B
ValueError: matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)
反转输入会产生相同的错误。令人沮丧的是,这样做确实seem to work for operations other than matmul。
有什么简单的方法可以覆盖我想要的matmul吗?
谢谢!