如何安全地子类化ndarray并获得与ndarray一致的行为 - 奇怪的nanmin / max结果?

时间:2016-11-14 22:04:58

标签: numpy

我正在尝试subclass an ndarray,以便我可以添加一些额外的字段。然而,当我这样做时,我在各种numpy函数中得到奇怪的行为。例如,nanmin返回现在返回一个我的新数组类型的对象,而之前我得到一个float64。为什么?这是nanmin还是我的班级的错误?

import numpy as np

class NDArrayWithColumns(np.ndarray):
    def __new__(cls, obj, columns=None):
        obj = obj.view(cls)
        obj.columns = tuple(columns)
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.columns = getattr(obj, 'columns', None)

NAN = float("nan")
r = np.array([1.,0.,1.,0.,1.,0.,1.,0.,NAN, 1., 1.])
print "MIN", np.nanmin(r), type(np.nanmin(r)) 

给出:

MIN 0.0 <type 'numpy.float64'>

>>> r = NDArrayWithColumns(r, ["a"])
>>> print "MIN", np.nanmin(r), type(np.nanmin(r))
MIN 0.0 <class '__main__.NDArrayWithColumns'>
>>> print r.shape 
(11,)

注意类型的变化,并且str(np.nanmin(r))显示1个字段,而不是11。

如果你感兴趣,我是子类,因为我想跟踪列名是单一类型的矩阵,但结构数组和记录类型数组允许不同的类型。)

1 个答案:

答案 0 :(得分:1)

根据docs,您需要实施在__array_wrap__ s结束时调用的ufunc方法:

    def __array_wrap__(self, out_arr, context=None):
        print('In __array_wrap__:')
        print('   self is %s' % repr(self))
        print('   arr is %s' % repr(out_arr))
        # then just call the parent
        return np.ndarray.__array_wrap__(self, out_arr, context)