按照说明here我创建了一个ndarray的子类,它为ndarray类添加了新属性。现在我想为新类定义一个比较运算符,除了比较数据之外,还要比较属性的值。所以我尝试了这个:
def __eq__(self, other):
return (self._prop1 == other._prop1) and \
(self._prop2 == other._prop2) and \
(self.data == other.data)
这允许比较T1 == T2
并返回布尔值。但是,由于我想将这些数组与其他ndarrays交换使用,我希望比较返回一个布尔数组。如果我没有定义我的__eq__
函数,那么比较会返回一个布尔数组,但是我无法检查属性。我该如何将两者合并?
答案 0 :(得分:1)
根据suggestion by hpaulj,我通过查看np.ma.core.MaskedArray.__eq__
找出了如何做到这一点。这是参考的最小实现。主要思想是在基类类型__eq__()
中的self
视图上调用numpy DerivedArray
。
class DerivedArray(np.ndarray):
def __new__(cls, input_array, prop1, prop2):
_baseclass = getattr(input_array, '_baseclass', type(input_array))
obj = np.asarray(input_array).view(cls)
obj._prop1 = prop1
obj._prop2 = prop2
obj._baseclass = _baseclass
return obj
def __array_finalize__(self, obj):
if obj is None:
return
else:
if not isinstance(obj, np.ndarray):
_baseclass = type(obj)
else:
_baseclass = np.ndarray
self._prop1 = getattr(obj, '_prop1', None)
self._prop2 = getattr(obj, '_prop2', None)
self._baseclass= getattr(obj, '_baseclass', _baseclass)
def _get_data(self):
"""Return the current data, as a view of the original
underlying data.
"""
return np.ndarray.view(self, self._baseclass)
_data = property(fget=_get_data)
data = property(fget=_get_data)
def __eq__(self, other):
attsame = (self._prop1 == other._prop1) and (self._prop2 == other._prop2)
if not attsame: return False
return self._data.__eq__(other)