子类化numpy ndarray以在超出范围的索引

时间:2017-05-19 18:54:57

标签: python arrays numpy

我想要一个2D numpy数组(NxN),当索引为负数或超出范围时将返回零(即我想要抑制索引为-1或者当索引为-1时发生的常规环绕索引索引为N时的IndexError。我可以这样做,只需在数组周围添加一个零边框,并将其视为一个基于1的数组,而不是基于0的数组,但这似乎是不优雅的。

我偶然发现了一些关于子类化numpy ndarray类并定义自己的__getitem__属性的答案。我的第一次尝试看起来像这样:

import numpy as np

class zeroPaddedArray(np.ndarray):
    def __getitem__(self, index):
        x,y = index
        if x < 0 or y < 0 or x >= self.shape[0] or y >= self.shape[1]:
            return 0
        return super(zeroPaddedArray, self).__getitem__(index)

这种工作,但只允许您以arr[x,y]访问数组元素,并在您尝试arr[x][y]时抛出错误。它还完全打破了许多其他功能,例如printprint arr给出TypeError: 'int' object is not iterable

我的下一次尝试是检查是否为索引提供了元组,如果没有,则默认为旧的行为。

import numpy as np

class zeroPaddedArray(np.ndarray):
    def __getitem__(self, index):
        if type(index) is tuple:
            x,y = index
            if x < 0 or y < 0 or x >= self.shape[0] or y >= self.shape[1]:
                return 0
        return super(zeroPaddedArray, self).__getitem__(index)
    else:
        return super(zeroPaddedArray, self).__getitem__(index)

当我将索引作为元组(arr[-1,-1]正确地给出0)时,这给了我所需的零填充行为,但允许其他函数正常工作。但是,现在我得到不同的结果取决于我索引事物的方式。例如:

a = np.ones((5,5))
b = a.view(zeroPaddedArray)
print b[-1][-1]
print b[-1,-1]

给出

>>>1.0
>>>0

我认为这可能是出于我的目的,但我不满意。无论我使用哪种语法进行索引,如何在不破坏所有其他ndarray功能的情况下调整它以提供所需的零填充行为?

1 个答案:

答案 0 :(得分:1)

以下是take的使用方式:

In [34]: a=np.zeros((5,5),int)
In [35]: a[1:4,1:4].flat=np.arange(9)
In [36]: a
Out[36]: 
array([[0, 0, 0, 0, 0],
       [0, 0, 1, 2, 0],
       [0, 3, 4, 5, 0],
       [0, 6, 7, 8, 0],
       [0, 0, 0, 0, 0]])
In [37]: np.take(a, np.arange(-1,6),1,mode="clip")
Out[37]: 
array([[0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 0, 0],
       [0, 0, 3, 4, 5, 0, 0],
       [0, 0, 6, 7, 8, 0, 0],
       [0, 0, 0, 0, 0, 0, 0]])
In [38]: np.take(a, np.arange(-1,6),0,mode="clip")
Out[38]: 
array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 1, 2, 0],
       [0, 3, 4, 5, 0],
       [0, 6, 7, 8, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]])

我使用clip模式扩展0边界。

还有一个np.pad函数,虽然它的所有通用性很长,而且没有速度解决方案(它最终会在每个维度上进行2个连接)。

np.lib.index_tricks.py有一些使用自定义类来播放索引技巧的好例子。

在深入研究ndarray子类之前,我建议编写函数或技巧类来测试你的想法。