numpy ndarray的子​​类没有按预期工作

时间:2014-05-19 10:09:43

标签: python numpy

`大家好。

我发现子类化ndarray时有一种奇怪的行为。

import numpy as np

class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        return obj

    def __init__(self, *args, **kwargs):
        return

    def __array_finalize__(self, obj):
        return

a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)

a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)

print a_sum.ndim #1
print b_sum.ndim #2

如您所见,keepdims参数对我的子类fooarray不起作用。它失去了一个轴。我怎能不避免这个问题?或者更一般地说,我怎样才能正确地将numpy ndarray子类化?

2 个答案:

答案 0 :(得分:4)

np.sum可以接受各种对象作为输入:不仅是ndarrays,还有列表,生成器,np.matrix等。 keepdims参数显然对列表或生成器没有意义。它也不适合np.matrix个实例,因为np.matrix总是有2个维度。如果您查看np.matrix.sum的通话签名,则会看到其sum方法没有keepdims参数:

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)

因此,ndarray的某些子类可能具有sum方法,这些方法没有keepdims参数。这是对Liskov substitution principle和你遇到的陷阱起源的不幸违反。

现在,如果你看一下the source code for np.sum,你会发现它是一个委托函数,试图根据第一个参数的类型确定要做什么。

如果第一个参数的类型不是ndarray,则会删除keepdims参数。这样做是因为将keepdims参数传递给np.matrix.sum会引发异常。

因为np.sum试图以最一般的方式进行委托,而不是对ndarray的子​​类可能采用的参数进行任何假设,所以在传递{keepdims时会丢弃fooarray参数1}}。

解决方法是不使用np.sum,而是调用a.sum。无论如何,这更直接,因为np.sum仅仅是一种委托功能。

import numpy as np


class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array, *args, **kwargs).view(cls)
        return obj

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim)  # 2
print(b_sum.ndim)  # 2

答案 1 :(得分:2)

详细说明@ mskimm的评论,如果你看看相关的 numpy的源代码core/fromnumeric.py的一部分,很明显为什么 a.sum(..., keepdims=True)有效,np.sum(a, ..., keepdims=True) 不:

def sum(a, axis=None, dtype=None, out=None, keepdims=False):
    ...
    if isinstance(a, _gentype):
        res = _sum_(a)
        if out is not None:
            out[...] = res
            return out
        return res
    elif type(a) is not mu.ndarray:
        try:
            sum = a.sum
        except AttributeError:
            return _methods._sum(a, axis=axis, dtype=dtype,
                                out=out, keepdims=keepdims)
        # NOTE: Dropping the keepdims parameters here...
        return sum(axis=axis, dtype=dtype, out=out)
    else:
        return _methods._sum(a, axis=axis, dtype=dtype,
                            out=out, keepdims=keepdims)
    ...

由于您已将[{1}}子类化,np.ndarraytype(a),而不是fooarray mu.ndarray,所以你最终会走到这一行:

# NOTE: Dropping the keepdims parameters here...
return sum(axis=axis, dtype=dtype, out=out)

keepdims关键字参数是ndarrays的一个相对较新的功能,目前尚未针对某些其他类似数组的类实现,例如np.matrixnp.ma.masked_array.sum()方法,因此非ndarray s当前该参数被删除的原因。