我试图将numpy的ndarray子类化。在我的子类MyClass
中,我已经将一个名为time
的字段添加为主数据的并行数组。
我的目标如下:假设我创建了一个MyClass实例,让我们称之为mc
。
我切片mc
,例如mc[2:6]
,我希望生成的对象不仅包含正确切片的np数组,还包含相应切片的time
数组。
这是我的尝试:
class MyClass(np.ndarray):
def __new__(cls, data, time=None):
obj = np.asarray(data).view(cls)
obj.time = time
return obj
def __array_finalize__(self, obj):
setattr(self, 'time', obj.time)
def __getitem__(self, item):
#print item #for testing
ret = super(MyClass, self).__getitem__(item)
ret.time = self.time.__getitem__(item)
return ret
这不起作用。经过几个小时的捣乱,我意识到这是因为当我呼叫mc[2:6]
时,__getitem__
实际上被多次调用。首先,当它被调用时,item
变量,如预期的那样,是slice(2,6,None)
。但是,包含super(MyClass, self)...
的行再次调用相同的函数,可能是为了检索切片的各个元素。
问题在于它为__getitem__
提供了一组奇怪的参数,总是负数。在mc[2:6]
的示例中,它更多次调用方法4,item
值为-4,-3,-2和-1。
正如您所看到的,这使我无法正确调整ret.time
变量,因为它会尝试多次修改它,通常使用无意义的索引。
我尝试过多种方式解决这个问题,包括复制对象和编辑该副本,获取对象的各种视图以及许多其他黑客,但似乎没有人能够解决__getitem__
重复的问题使用负索引调用,不与请求的切片对齐。
对于正在发生的事情的任何帮助或解释将不胜感激。
答案 0 :(得分:2)
我有一个类似的问题,我以numpy matrix类为例解决了。正如您在__getitem__
中创建数组之前注意到的那样,可以多次调用__array_finalize__
。因此解决方案是将潜在的新索引存储在__getitem__
中,但将其设置在__array_finalize__
中。
class MyClass(np.ndarray):
def __new__(cls, data, time=None):
obj = np.asarray(data).view(cls)
obj.time = time
return obj
def __array_finalize__(self, obj):
setattr(self, 'time', obj.time)
try:
self.time = self.time[obj._new_time_index]
except:
pass
def __getitem__(self, item):
try:
if isinstance(item, (slice, int)):
self._new_time_index = item
else:
self._new_time_index = item[0]
except:
pass
return super().__getitem__(item)
答案 1 :(得分:0)
如果您想更新切片上的time
,请尝试
if isinstance(item, slice):
ret.time = self.time.__getitem__(item)
在__getitem__
方法中。
然后你的time
调整代码每次切片只调用一次,从数组中获取单个项目时从不执行。
答案 2 :(得分:0)
我解决问题的方式(试图做类似的事情)如下:
class MyClass(np.ndarray):
...
def __getitem__(self, item):
#print item #for testing
ret = super(MyClass, self).__getitem__(item)
if not isinstance(self, MyClass):
return ret
ret.time = self.time.__getitem__(item)
return ret
这样,如果__getitem__
被多次调用,则您只会在调用实例为time
的第一个调用上调整MyClass
方法。