按索引复制numpy数组的单个轴

时间:2015-03-19 14:26:53

标签: python numpy

我正在寻找一种优雅的方法来通过索引提取numpy数组的单个轴的值。例如:

x = np.arange(16).reshape((4,4))
a = x[0]
b = x[:, 0] 

是我通常做的,但我正在寻找类似的东西:

a = get( x, axis=0, index=0)
b = get( x, axis=1, index=0)

这可能有一些奇特的功能吗?

2 个答案:

答案 0 :(得分:2)

您可以使用np.rollaxis将您感兴趣的轴移到前面,然后像往常一样将其索引:

def get(x, axis=0, index=0):
    return np.rollaxis(x, axis, 0)[index]

x = np.arange(27).reshape(3, 3, 3)

assert np.all(get(x, 1, 2) == x[:, 2, :])

正如Joe正确指出的那样,这会将视图返回到x。为了强制复制,您可以使用.copy()方法:

cpy = get(x, 1, 2).copy()

答案 1 :(得分:0)

您可以直接使用__getitem__魔术方法在功能上获得相同的功能,但更容易支持动态参数。以下是使用itertools包的示例:

def get(matrix, axis, index):
    return a.__getitem__(tuple(chain(repeat(slice(None), axis), (index,))))

这会创建一个带有切片对象的元组,该切片对象代表a[:]重复axis次的冒号,最后是index。我认为可以清理元组生成,但我现在想不出更干净的方式。

一个示例用法是:

a = np.arange(9).reshape(3, 3) # [[0 1 2], [3 4 5], [6 7 8]]
get(a, axis=0, index=0) # [0 1 2]
get(a, axis=1, index=0) # [0 3 6]
get(a, axis=0, index=1) # [3 4 5]
get(a, axis=1, index=1) # [1 4 7]