numpy ndarray子类:ufunc不返回标量类型

时间:2013-10-07 11:40:06

标签: python numpy subclass scalar multidimensional-array

对于numpy.ndarray子类,ufunc输出具有相同的类型。这通常是好的,但我想用标量输出的ufunc返回标量类型(例如numpy.float64)。

示例:

import numpy as np

class MyArray(np.ndarray):
    def __new__(cls, array):
        obj = np.asarray(array).view(cls)
        return obj

a = MyArray(np.arange(5))
a*2
# MyArray([0, 2, 4, 6, 8])  => same class as original (i.e. MyArray), ok

a.sum()
# MyArray(10)               => same as original, but here I'd expect np.int64

type(2*a) is type(a.sum())
# True                    
b = a.view(np.ndarray)
type(2*b) is type(b.sum())    
# False

对于标准numpy数组,标量输出具有标量类型。那么如何为我的子类提供相同的行为?

我在OSX 10.6上使用Python 2.7.3和numpy 1.6.2

1 个答案:

答案 0 :(得分:3)

您需要使用如下所示的函数覆盖ndarray子类中的__array_wrap__

def __array_wrap__(self, obj):
    if obj.shape == ():
        return obj[()]    # if ufunc output is scalar, return it
    else:
        return np.ndarray.__array_wrap__(self, obj)
在ufuncs之后调用

__array_wrap__进行清理工作。在默认实现中,特殊情况是精确的ndarrays(但不是子类)将零排数组转换为标量。对于某些版本的numpy来说至少也是如此。