避免numpy为重载的运算符

时间:2016-03-08 11:57:15

标签: python numpy operator-overloading

默认情况下,如果numpy不知道另一个对象的类型,则会跨数组分配操作。这在大多数情况下效果很好。例如,以下行为符合预期。

np.arange(5) + 5 # = [5, 6, 7, 8, 9]

我想定义一个覆盖加法运算符的类,如下面的代码所示。

class Example:
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return other + self.value

    def __radd__(self, other):
        return other + self.value

它适用于标量值。例如,

np.arange(5) + Example(5) # = [5, 6, 7, 8, 9]

然而,它并没有完全符合我对矢量值的要求。例如,

np.arange(5) + Example(np.arange(5)) 

产生输出

array([array([0, 1, 2, 3, 4]), array([1, 2, 3, 4, 5]),
   array([2, 3, 4, 5, 6]), array([3, 4, 5, 6, 7]),
   array([4, 5, 6, 7, 8])], dtype=object)

因为前面的numpy数组的__add__运算符优先于我定义的__radd__运算符。 Numpy的__add__运算符为numpy数组的每个元素调用__radd__,产生一个数组数组。如何避免numpy分发操作?我想避免子类化numpy数组。

1 个答案:

答案 0 :(得分:2)

对于每个np.ndarray和不太急切的子类(例如在np.ma.MaskedArray早期的numpy版本中忽略它),即使你不是子类,也可以定义__array_priority__直接np.ndarray

这背后的想法很简单:具有较高优先级的子类决定哪个运算符定义数学运算而不是运算的顺序。

与您Example合作的一个工作示例是:

class Example:

    # Define this priority
    __array_priority__ = 2

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return other + self.value

    def __radd__(self, other):
        return other + self.value


import numpy as np
np.arange(5) + Example(np.arange(5)) 
# returns array([0, 2, 4, 6, 8])

所以它按照需要运作。但请注意,依靠这种方法存在一些微妙的问题:

它不适用于MaskedArrays,因为它们的优先级为15(因此您需要将优先级更改为16+以使其正常工作):

import numpy as np
np.ma.array(np.arange(5)) + Example(np.arange(5)) 

# returns:
masked_array(data = [array([0, 1, 2, 3, 4]) array([1, 2, 3, 4, 5])    array([2, 3, 4, 5, 6])
array([3, 4, 5, 6, 7]) array([4, 5, 6, 7, 8])],
         mask = False,
   fill_value = ?)

例如,它不能与astropy.units.Quantity一起使用,因为他们已将其优先级定义为10000

import astropy.units as u
(np.arange(5)*u.dimensionless_unscaled) + Example(np.arange(5)) 
#returns:
<Quantity [array([ 0.,  1.,  2.,  3.,  4.]),
           array([ 1.,  2.,  3.,  4.,  5.]),
           array([ 2.,  3.,  4.,  5.,  6.]),
           array([ 3.,  4.,  5.,  6.,  7.]),
           array([ 4.,  5.,  6.,  7.,  8.])]>

它不适用于任何不使用numpy - 机器的课程。