在一个项目中,我创建了一个类,我需要在这个新类和一个真实矩阵之间进行操作,所以我重载了这个__rmul__
函数
class foo(object):
aarg = 0
def __init__(self):
self.aarg = 1
def __rmul__(self,A):
print(A)
return 0
def __mul__(self,A):
print(A)
return 0
但是当我打电话给它时,结果并不是我所期待的
A = [[i*j for i in np.arange(2) ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C = A * R
输出:
0
0
0
1
0
2
似乎该函数被调用了6次,每个元素一次。
相反,__mul__
函数非常有用
C = R * A
输出:
[[0 0]
[0 1]
[0 2]]
如果A
不是一个数组,而只是一个列表列表,那么两者都可以正常工作
A = [[i*j for i in np.arange(2) ] for j in np.arange(3)]
R = foo()
C = A * R
C = R * A
输出
[[0, 0], [0, 1], [0, 2]]
[[0, 0], [0, 1], [0, 2]]
我真的希望我的__rmul__
函数也可以在数组上工作(我的原始乘法函数不是可交换的)。我该如何解决?
答案 0 :(得分:6)
行为是预期的。
首先,你必须了解像x*y
这样的操作是如何实际执行的。 python解释器将首先尝试计算x.__mul__(y)
。
如果此调用返回NotImplemented
,则然后会尝试计算y.__rmul__(x)
。
,当y
是x
类型的正确子类时,在这种情况下,解释程序将首先考虑y.__rmul__(x)
然后x.__mul__(y)
。< / p>
现在发生的事情是numpy
根据他是否认为参数是标量或数组来区别对待参数。
当处理数组*
时,逐元素乘法,而标量乘法将数组的所有条目乘以给定的标量。
在你的情况下,foo()
被numpy视为标量,因此numpy将数组的所有元素乘以foo
。此外,由于numpy不知道类型foo
,它返回一个带dtype=object
的数组,因此返回的对象是:
array([[0, 0],
[0, 0],
[0, 0]], dtype=object)
注意:当您尝试计算产品时,numpy
的数组不返回NotImplemented
,因此解释程序会调用numpy的数组{ {1}}方法,如我们所说,执行标量乘法。在这一点上,numpy将尝试将数组的每个条目乘以&#34;标量&#34; __mul__
,这里是调用foo()
方法的地方,因为当__rmul__
调用NotImplemented
时,数组中的数字会返回__mul__
参数。
显然,如果您将参数的顺序更改为初始乘法,则会立即调用foo
方法,并且您不会遇到任何问题。
因此,要回答您的问题,处理此问题的一种方法是让__mul__
继承foo
,以便适用第二条规则:
ndarray
警告subclassing ndarray
isn't straightforward。
此外,您可能还有其他副作用,因为现在您的课程是class foo(np.ndarray):
def __new__(cls):
# you must implement __new__
# code as before
。
答案 1 :(得分:3)
我无法像Bakuriu那样准确地解释潜在的问题,但可能还有另一种解决方案。
您可以通过定义__array_priority__
强制numpy使用您的评估方法。正如numpy docs中的here所述。
在您的情况下,您必须将班级定义更改为:
MAGIC_NUMBER = 15.0
# for the necessary lowest values of MAGIC_NUMBER look into the numpy docs
class foo(object):
__array_priority__ = MAGIC_NUMBER
aarg = 0
def __init__(self):
self.aarg = 1
def __rmul__(self,A):
print(A)
return 0
def __mul__(self,A):
print(A)
return 0
答案 2 :(得分:3)
您可以在班级中定义__numpy_ufunc__
功能。它甚至可以无需子类化 np.ndarray
。您可以找到文档here。
以下是基于您的案例的示例:
class foo(object):
aarg = 0
def __init__(self):
self.aarg = 1
def __numpy_ufunc__(self, *args):
pass
def __rmul__(self,A):
print(A)
return 0
def __mul__(self,A):
print(A)
return 0
如果我们尝试,
A = [[i*j for i in np.arange(2) ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C = A * R
输出:
[[0 0]
[0 1]
[0 2]]
有效!