什么(numpy)__array_wrap__呢?

时间:2015-07-23 15:05:05

标签: python numpy

我第一次潜入SciPy LinAlg模块,我看到了这个功能:

def _makearray(a):
    new = asarray(a)
    wrap = getattr(a, "__array_prepare__", new.__array_wrap__)
    return new, wrap

__array_wrap__完全做了什么?我找到了documentation,但我不明白这个解释:

 At the end of every ufunc, this method is called on the input object with the
 highest array priority, or the output object if one was specified. The ufunc-
 computed array is passed in and whatever is returned is passed to the user. 
 Subclasses inherit a default implementation of this method, which transforms the
 array into a new instance of the object’s class. Subclasses may opt to use this
 method to transform the output array into an instance of the subclass and update
 metadata before returning the array to the user.

这是否意味着它会将任何函数的输出重新转换回array,因为它可能会被逐个元素处理分解成其他内容?相关地,无论解释如何,将此wrap作为对象意味着什么?你会用它做什么?

我正在查看numpy.linalg.inv的代码...这里的包装是什么?

    **a, wrap = _makearray(a)**
    _assertRankAtLeast2(a)
    _assertNdSquareness(a)
    t, result_t = _commonType(a)

    if a.shape[-1] == 0:
        # The inner array is 0x0, the ufunc cannot handle this case
        **return wrap(empty_like(a, dtype=result_t))**

    signature = 'D->D' if isComplexType(t) else 'd->d'
    extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
    ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
    return wrap(ainv.astype(result_t))

1 个答案:

答案 0 :(得分:4)

np.ma.masked_array.__array_wrap__是更新元数据(mask)的数组子类的示例。

File:        /usr/lib/python3/dist-packages/numpy/ma/core.py
Definition:  np.ma.masked_array.__array_wrap__(self, obj, context=None)
Source:
    def __array_wrap__(self, obj, context=None):
        """
        Special hook for ufuncs.
        Wraps the numpy array and sets the mask according to context.
        """

np.matrix.__array_wrap__似乎继承了ndarray版本。我的猜测是因为matrix,而子类,没有需要更新的元数据。

一般来说,hook的想法是,它是在正常处理中深度调用的函数。默认方法可能不会执行任何操作。但这是子类可以采取特殊行动的一种方式。类开发人员编写这样的钩子,以便类用户不必担心这些细节。使用__...__名称,它不是公共界面的一部分 - 尽管Python让我们在幕后巅峰。

wrapping的示例,即返回与输入具有相同类的数组:

In [659]: np.cumsum(np.arange(10))
Out[659]: array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32)

In [660]: np.cumsum(np.matrix(np.arange(10)))
Out[660]: matrix([[ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45]], dtype=int32

In [665]: np.cumsum(np.ma.masked_array(np.arange(10)))
Out[665]: 
masked_array(data = [ 0  1  3  6 10 15 21 28 36 45],
             mask = False,
       fill_value = 999999)

返回的值完全相同,但数组子类因输入类而异。

cumsum可能不是最好的例子。蒙面数组具有自己的cumsum版本,其中一个将屏蔽值视为0

In [679]: m=np.ma.masked_array(np.arange(10),np.arange(10)%2)

In [680]: m
Out[680]: 
masked_array(data = [0 -- 2 -- 4 -- 6 -- 8 --],
             mask = [False  True False  True False  True False  True False  True],
       fill_value = 999999)

In [681]: np.cumsum(m)
Out[681]: 
masked_array(data = [0 -- 2 -- 6 -- 12 -- 20 --],
             mask = [False  True False  True False  True False  True False  True],
       fill_value = 999999)

add.accumulatecumsum类似,但没有特殊的屏蔽版本:

In [682]: np.add.accumulate(np.arange(10))
Out[682]: array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32)

In [683]: np.add.accumulate(m)
Out[683]: 
masked_array(data = [ 0  1  3  6 10 15 21 28 36 45],
             mask = False,
       fill_value = 999999)

这最后是一个蒙面数组,但是掩码是默认的False,并且掩码值包含在总和中。