在子类化numpy ndarray时,如何正确修改__getitem__?

时间:2015-07-08 02:57:36

标签: python numpy multidimensional-array subclass slice

我试图将numpy的ndarray子类化。在我的子类MyClass中,我已经将一个名为time的字段添加为主数据的并行数组。

我的目标如下:假设我创建了一个MyClass实例,让我们称之为mc。 我切片mc,例如mc[2:6],我希望生成的对象不仅包含正确切片的np数组,还包含相应切片的time数组。

这是我的尝试:

class MyClass(np.ndarray):
    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = time
        return obj
    def __array_finalize__(self, obj):
        setattr(self, 'time', obj.time)
    def __getitem__(self, item):
        #print item #for testing
        ret = super(MyClass, self).__getitem__(item)
        ret.time = self.time.__getitem__(item)
        return ret

这不起作用。经过几个小时的捣乱,我意识到这是因为当我呼叫mc[2:6]时,__getitem__实际上被多次调用。首先,当它被调用时,item变量,如预期的那样,是slice(2,6,None)。但是,包含super(MyClass, self)...的行再次调用相同的函数,可能是为了检索切片的各个元素。

问题在于它为__getitem__提供了一组奇怪的参数,总是负数。在mc[2:6]的示例中,它更多次调用方法4,item值为-4,-3,-2和-1。

正如您所看到的,这使我无法正确调整ret.time变量,因为它会尝试多次修改它,通常使用无意义的索引。

我尝试过多种方式解决这个问题,包括复制对象和编辑该副本,获取对象的各种视图以及许多其他黑客,但似乎没有人能够解决__getitem__重复的问题使用负索引调用,不与请求的切片对齐。

对于正在发生的事情的任何帮助或解释将不胜感激。

3 个答案:

答案 0 :(得分:2)

我有一个类似的问题,我以numpy matrix类为例解决了。正如您在__getitem__中创建数组之前注意到的那样,可以多次调用__array_finalize__。因此解决方案是将潜在的新索引存储在__getitem__中,但将其设置在__array_finalize__中。

class MyClass(np.ndarray):
    def __new__(cls, data, time=None):
        obj = np.asarray(data).view(cls)
        obj.time = time
        return obj
    def __array_finalize__(self, obj):
        setattr(self, 'time', obj.time)
        try:
            self.time = self.time[obj._new_time_index]
        except:
            pass

    def __getitem__(self, item):
        try:
            if isinstance(item, (slice, int)):
                self._new_time_index = item
            else:
                self._new_time_index = item[0]
        except: 
            pass
        return super().__getitem__(item)

答案 1 :(得分:0)

如果您想更新切片上的time,请尝试

if isinstance(item, slice):
    ret.time = self.time.__getitem__(item)

__getitem__方法中。

然后你的time调整代码每次切片只调用一次,从数组中获取单个项目时从不执行。

答案 2 :(得分:0)

我解决问题的方式(试图做类似的事情)如下:

class MyClass(np.ndarray):
    ...

    def __getitem__(self, item):
        #print item #for testing
        ret = super(MyClass, self).__getitem__(item)
        if not isinstance(self, MyClass):
            return ret

        ret.time = self.time.__getitem__(item)
        return ret

这样,如果__getitem__被多次调用,则您只会在调用实例为time的第一个调用上调整MyClass方法。