防止numpy矢量化乘法

时间:2016-10-17 23:31:39

标签: python numpy

我创建了一组用于处理度和弧度的小助手类。我已经包含了一些numpy ufuncs(弧度,度,sin,cos ......),这样我就可以在numpy数组中拥有deg个对象,并在其上执行numpy trig操作(例如,np.cos(np.array([5*deg, 10*deg, 15*deg]))) numpy数组。

但是,我发现当在RHS上将ndarray“乘以”deg类时,会在数组上调用numpy对象的__mul__方法,而不是{{ 1}}导致UnitMeta.__rmul__根据需要提升TypeError。当NumberMixin.__new__是LHS时,它可以正常工作(引发错误),例如deg

为简洁起见,下面仅显示了部分类别。

deg * np.array([1])

其他一切运作良好。

我想阻止numpy这样做,以便我可以写更像“普通数学”的东西,例如:

'deg.py'
from numbers import Number
import numpy as np

class UnitMeta(type):
    def __mul__(cls, other):
        return cls(other)
    def __rmul__(cls, other):
        '''So can write things like "1 * deg"'''
        return cls(other)

class NumberMixin():
    def __new__(cls, v):
        if not isinstance(v, Number):
            raise TypeError('A valid numeric type value is required. {} is not numeric.'.format(type(v)))
        return super().__new__()
    def __mul__(self, other):
        return self._v * other
    def __rmul__(self, other):
        return self.__mul__(other)
    def radians(self): # NOTE: overridden in deg below 
        return np.radians(self._v)
    def degrees(self): # NOTE: overridden in deg below 
        return np.degrees(self._v)
    def sin(self):
        return np.sin(self._v)
    def cos(self):
        return np.cos(self._v)
    def tan(self):
        return np.tan(self._v)

class deg(NumberMixin, metaclass = UnitMeta):
    def __init__(self, d = 0.0):
        if isinstance(d, deg):
            self._v = d._v
            self._deg = d._deg
        else:
            self._v = np.radians(d)
            self._deg = d
    def __mul__(self, other):
        if isinstance(type(other),UnitMeta):
            return NotImplemented
        else:
            return super().__mul__(other)
    def __str__(self):
        return str(self._deg) + '°'
    def __repr__(self):
        return str('deg({})'.format(self._deg))
    def __format__(self, spec):
        return self._deg.__format__(spec) + '°'
    def degrees(self):
        return self
    def radians(self):
        return self._v


if __name__ == '__main__':
    try:
        print('FAILURE: ', deg * np.array([1]) , ' exception not caught')
    except TypeError:
        print('SUCCESS: deg * np.array([1]) exception caught.')
    try:
        print('FAILURE: ', np.array([1]) * deg, ' exception not caught')
    except TypeError:
        print('SUCCESS: np.array([1]) * deg exception caught.')        

...但是当出现意外情况时,会出现异常,并且不会被吞下:

5 * deg

有什么建议吗?到目前为止,我还没有和numpy合作过,所以如果有明显的解决方案,请道歉。

0 个答案:

没有答案