`大家好。
我发现子类化ndarray时有一种奇怪的行为。
import numpy as np
class fooarray(np.ndarray):
def __new__(cls, input_array, *args, **kwargs):
obj = np.asarray(input_array).view(cls)
return obj
def __init__(self, *args, **kwargs):
return
def __array_finalize__(self, obj):
return
a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)
a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)
print a_sum.ndim #1
print b_sum.ndim #2
如您所见,keepdims
参数对我的子类fooarray
不起作用。它失去了一个轴。我怎能不避免这个问题?或者更一般地说,我怎样才能正确地将numpy ndarray子类化?
答案 0 :(得分:4)
np.sum
可以接受各种对象作为输入:不仅是ndarrays,还有列表,生成器,np.matrix
等。 keepdims
参数显然对列表或生成器没有意义。它也不适合np.matrix
个实例,因为np.matrix
总是有2个维度。如果您查看np.matrix.sum
的通话签名,则会看到其sum
方法没有keepdims
参数:
Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)
因此,ndarray
的某些子类可能具有sum
方法,这些方法没有keepdims
参数。这是对Liskov substitution principle和你遇到的陷阱起源的不幸违反。
现在,如果你看一下the source code for np.sum
,你会发现它是一个委托函数,试图根据第一个参数的类型确定要做什么。
如果第一个参数的类型不是ndarray
,则会删除keepdims
参数。这样做是因为将keepdims参数传递给np.matrix.sum
会引发异常。
因为np.sum
试图以最一般的方式进行委托,而不是对ndarray的子类可能采用的参数进行任何假设,所以在传递{keepdims
时会丢弃fooarray
参数1}}。
解决方法是不使用np.sum
,而是调用a.sum
。无论如何,这更直接,因为np.sum
仅仅是一种委托功能。
import numpy as np
class fooarray(np.ndarray):
def __new__(cls, input_array, *args, **kwargs):
obj = np.asarray(input_array, *args, **kwargs).view(cls)
return obj
a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)
a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)
print(a_sum.ndim) # 2
print(b_sum.ndim) # 2
答案 1 :(得分:2)
详细说明@ mskimm的评论,如果你看看相关的
numpy的源代码core/fromnumeric.py
的一部分,很明显为什么
a.sum(..., keepdims=True)
有效,np.sum(a, ..., keepdims=True)
不:
def sum(a, axis=None, dtype=None, out=None, keepdims=False):
...
if isinstance(a, _gentype):
res = _sum_(a)
if out is not None:
out[...] = res
return out
return res
elif type(a) is not mu.ndarray:
try:
sum = a.sum
except AttributeError:
return _methods._sum(a, axis=axis, dtype=dtype,
out=out, keepdims=keepdims)
# NOTE: Dropping the keepdims parameters here...
return sum(axis=axis, dtype=dtype, out=out)
else:
return _methods._sum(a, axis=axis, dtype=dtype,
out=out, keepdims=keepdims)
...
由于您已将[{1}}子类化,np.ndarray
为type(a)
,而不是fooarray
mu.ndarray
,所以你最终会走到这一行:
# NOTE: Dropping the keepdims parameters here...
return sum(axis=axis, dtype=dtype, out=out)
keepdims
关键字参数是ndarrays
的一个相对较新的功能,目前尚未针对某些其他类似数组的类实现,例如np.matrix
或np.ma.masked_array
有.sum()
方法,因此非ndarray
s当前该参数被删除的原因。