Python Numpy中的Array和__rmul__运算符

时间:2016-07-06 17:19:08

标签: python arrays numpy

在一个项目中,我创建了一个类,我需要在这个新类和一个真实矩阵之间进行操作,所以我重载了这个__rmul__函数

class foo(object):

    aarg = 0

    def __init__(self):
        self.aarg = 1


    def __rmul__(self,A):
        print(A)
        return 0

    def __mul__(self,A):
        print(A)
        return 0

但是当我打电话给它时,结果并不是我所期待的

A = [[i*j for i in np.arange(2)  ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C =  A * R

输出:

0
0
0
1
0
2

似乎该函数被调用了6次,每个元素一次。

相反,__mul__函数非常有用

C = R * A

输出:

[[0 0]
 [0 1]
 [0 2]]

如果A不是一个数组,而只是一个列表列表,那么两者都可以正常工作

A = [[i*j for i in np.arange(2)  ] for j in np.arange(3)]
R = foo()
C =  A * R
C = R * A

输出

[[0, 0], [0, 1], [0, 2]]
[[0, 0], [0, 1], [0, 2]]

我真的希望我的__rmul__函数也可以在数组上工作(我的原始乘法函数不是可交换的)。我该如何解决?

3 个答案:

答案 0 :(得分:6)

行为是预期的。

首先,你必须了解像x*y这样的操作是如何实际执行的。 python解释器将首先尝试计算x.__mul__(y)。 如果此调用返回NotImplemented,则然后会尝试计算y.__rmul__(x) ,当yx类型的正确子类时,在这种情况下,解释程序将首先考虑y.__rmul__(x)然后x.__mul__(y)。< / p>

现在发生的事情是numpy根据他是否认为参数是标量或数组来区别对待参数。

当处理数组*时,逐元素乘法,而标量乘法将数组的所有条目乘以给定的标量。

在你的情况下,foo()被numpy视为标量,因此numpy将数组的所有元素乘以foo。此外,由于numpy不知道类型foo,它返回一个带dtype=object的数组,因此返回的对象是:

array([[0, 0],
       [0, 0],
       [0, 0]], dtype=object)

注意:当您尝试计算产品时,numpy的数组返回NotImplemented,因此解释程序会调用n​​umpy的数组{ {1}}方法,如我们所说,执行标量乘法。在这一点上,numpy将尝试将数组的每个条目乘以&#34;标量&#34; __mul__,这里是调用foo()方法的地方,因为当__rmul__调用NotImplemented时,数组中的数字会返回__mul__参数。

显然,如果您将参数的顺序更改为初始乘法,则会立即调用foo方法,并且您不会遇到任何问题。

因此,要回答您的问题,处理此问题的一种方法是让__mul__继承foo,以便适用第二条规则:

ndarray

警告subclassing ndarray isn't straightforward。 此外,您可能还有其他副作用,因为现在您的课程是class foo(np.ndarray): def __new__(cls): # you must implement __new__ # code as before

答案 1 :(得分:3)

我无法像Bakuriu那样准确地解释潜在的问题,但可能还有另一种解决方案。

您可以通过定义__array_priority__强制numpy使用您的评估方法。正如numpy docs中的here所述。

在您的情况下,您必须将班级定义更改为:

MAGIC_NUMBER = 15.0
# for the necessary lowest values of MAGIC_NUMBER look into the numpy docs
class foo(object):
    __array_priority__ = MAGIC_NUMBER
    aarg = 0

    def __init__(self):
        self.aarg = 1


    def __rmul__(self,A):
        print(A)
        return 0

    def __mul__(self,A):
        print(A)
        return 0

答案 2 :(得分:3)

您可以在班级中定义__numpy_ufunc__功能。它甚至可以无需子类化 np.ndarray。您可以找到文档here

以下是基于您的案例的示例:

class foo(object):

    aarg = 0

    def __init__(self):
        self.aarg = 1

    def __numpy_ufunc__(self, *args):
        pass

    def __rmul__(self,A):
        print(A)
        return 0

    def __mul__(self,A):
        print(A)
        return 0

如果我们尝试,

A = [[i*j for i in np.arange(2)  ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C =  A * R

输出:

[[0 0]
 [0 1]
 [0 2]]

有效!