带有__add__的自定义类,可以使用NumPy数组添加

时间:2017-08-29 19:40:45

标签: python numpy operator-overloading

我有一个自定义类,实现__add__和__radd__为

import numpy

class Foo(object):

    def __init__(self, val):
        self.val = val

    def __add__(self, other):
        print('__add__')
        print('type self = %s' % type(self))
        print('type other = %s' % type(other))
        return self.val + other

    def __radd__(self, other):
        print('__radd__')
        print('type self = %s' % type(self))
        print('type other = %s' % type(other))
        return other + self.val

我首先测试__add __

r1 = Foo(numpy.arange(3)) + numpy.arange(3,6)
print('type results = %s' % type(r1))
print('result = {}'.format(r1))

它会导致预期的结果

>>> __add__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'numpy.ndarray'>
>>> type results = <type 'numpy.ndarray'>
>>> result = [3  5  7]

然而,测试__radd __

r2 = numpy.arange(3) + Foo(numpy.arange(3,6))
print('type results = %s' % type(r2))
print('result = {}'.format(r2))

我得到了

>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> __radd__
>>> type self = <class '__main__.Foo'>
>>> type other = <type 'int'>
>>> type results = <type 'numpy.ndarray'>
>>> result = [array([3, 4, 5]) array([4, 5, 6]) array([5, 6, 7])]

这对我没有任何意义。 NumPy是否为任意对象重载__add__,然后优先于我的__radd__?如果是的话,为什么他们会做这样的事情?另外,我怎么能避免这种情况,我真的希望能够在左侧添加带有NumPy数组的自定义类。感谢。

1 个答案:

答案 0 :(得分:2)

这在评论中隐藏了,但应该是答案。

默认情况下,Numpy操作以每个元素为基础,以任意对象为对象,然后尝试按元素执行操作(根据广播规则)。

例如,这意味着给定

class N:
    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        return self.x + other

    def __radd__(self, other):
        return other + self.x

由于Python运算符解析

N(3) + np.array([1, 2, 3])

将以__add__到达上面的N(3),并将整个数组作为other一次,然后执行常规的Numpy加法。

另一方面

np.array([1, 2, 3]) + N(3)

将成功输入Numpy的ufunc(在这种情况下为运算符),因为它们将任意对象作为“其他”,然后尝试依次执行:

1 + N(3)
2 + N(3)
3 + N(3)

这意味着上面的__add__调用3次,而不是一次,每个元素调用一次,从而大大降低了操作速度。要禁用此行为,并在获取Numpy对象时让NotImplementedError引发N,从而允许RHS重载radd接管,请在您的正文中添加以下内容课:

class N:
    ...
    __numpy_ufunc__ = None # Numpy up to 13.0
    __array_ufunc__ = None # Numpy 13.0 and above

如果不是向后兼容的问题,那么只需要第二个。