比较ndarray派生类

时间:2015-12-03 18:51:40

标签: python numpy

按照说明here我创建了一个ndarray的子​​类,它为ndarray类添加了新属性。现在我想为新类定义一个比较运算符,除了比较数据之外,还要比较属性的值。所以我尝试了这个:

def __eq__(self, other):
    return (self._prop1 == other._prop1) and \
           (self._prop2 == other._prop2) and \
           (self.data == other.data)

这允许比较T1 == T2并返回布尔值。但是,由于我想将这些数组与其他ndarrays交换使用,我希望比较返回一个布尔数组。如果我没有定义我的__eq__函数,那么比较会返回一个布尔数组,但是我无法检查属性。我该如何将两者合并?

1 个答案:

答案 0 :(得分:1)

根据suggestion by hpaulj,我通过查看np.ma.core.MaskedArray.__eq__找出了如何做到这一点。这是参考的最小实现。主要思想是在基类类型__eq__()中的self视图上调用numpy DerivedArray

class DerivedArray(np.ndarray):
    def __new__(cls, input_array, prop1, prop2):       
        _baseclass = getattr(input_array, '_baseclass', type(input_array))
        obj = np.asarray(input_array).view(cls)

        obj._prop1    = prop1
        obj._prop2    = prop2
        obj._baseclass = _baseclass
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        else:
            if not isinstance(obj, np.ndarray):
                _baseclass = type(obj)
            else:
                _baseclass = np.ndarray

        self._prop1    = getattr(obj, '_prop1', None)
        self._prop2    = getattr(obj, '_prop2', None)
        self._baseclass= getattr(obj, '_baseclass', _baseclass)

    def _get_data(self):
        """Return the current data, as a view of the original
        underlying data.
        """
        return np.ndarray.view(self, self._baseclass)

    _data = property(fget=_get_data)
    data  = property(fget=_get_data)

    def __eq__(self, other):
        attsame = (self._prop1 == other._prop1) and (self._prop2 == other._prop2)
        if not attsame: return False
        return self._data.__eq__(other)