结合多维numpy数组的切片和广播索引

时间:2014-12-04 15:14:18

标签: python arrays numpy multidimensional-array slice

我有一个ND numpy数组(比方说3x3x3),我想提取一个子数组,结合切片和索引数组。例如:

import numpy as np  
A = np.arange(3*3*3).reshape((3,3,3))
i0, i1, i2 = ([0,1], [0,1,2], [0,2])
ind1 = j0, j1, j2 = np.ix_(i0, i1, i2)
ind2 = (j0, slice(None), j2)
B1 = A[ind1]
B2 = A[ind2]

我希望B1 == B2,但实际上,形状是不同的

>>> B1.shape
(2, 3, 2)
>>> B2.shape
(2, 1, 2, 3)
>>> B1
array([[[ 0,  2],
        [ 3,  5],
        [ 6,  8]],

       [[ 9, 11],
        [12, 14],
        [15, 17]]])
>>> B2
array([[[[ 0,  3,  6],
         [ 2,  5,  8]]],

       [[[ 9, 12, 15],
         [11, 14, 17]]]])

有人理解为什么?有没有想过如何通过操纵'A'和'ind2'对象来获得'B1'?目标是它适用于任何nD阵列,并且我不必寻找我想完全保留的尺寸形状(希望我足够清楚:))。谢谢!
---编辑---
为了更清楚,我希望有一个“有趣”的功能

A[fun(ind2)] == B1

3 个答案:

答案 0 :(得分:0)

ind1的索引子空间是(2,),(3,),(2,),结果B(2,3,2)。这是高级索引的简单案例。

ind2是(高级)部分索引的一种情况。有2个索引数组和1个切片。高级索引文档声明:

  

如果索引子空间是分开的(通过切片对象),则首先是广播的索引空间,然后是x的切片子空间。

在这种情况下,高级索引构造(2,2)数组(来自第1和第3个索引),并在最后附加切片维度,从而生成(2,2,3)数组。

我在https://stackoverflow.com/a/27097133/901925

中更详细地解释了推理

修复像ind2这样的元组的方法是将每个切片扩展为数组。我最近在np.insert中看到了这一点。

np.arange(*ind2[1].indices(3))

:扩展为[0,1,2]。但替换必须具有正确的形状。

ind=list(ind2)
ind[1]=np.arange(*ind2[1].indices(3)).reshape(1,-1,1)
A[ind]

我正在忽略确定哪个术语是切片,其尺寸和相关重塑的细节。目标是重现i1

如果索引是由ix_之外的其他内容生成的,那么重塑此切片可能会更困难。例如

A[np.array([0,1])[None,:,None],:,np.array([0,2])[None,None,:]] # (1,2,2,3)
A[np.array([0,1])[None,:,None],np.array([0,1,2])[:,None,None],np.array([0,2])[None,None,:]]
# (3,2,2)

扩展切片必须与广播下的其他阵列兼容。

索引后交换轴是另一种选择。但是,逻辑可能更复杂。 但在某些情况下,转置可能实际上更简单:

A[np.array([0,1])[:,None],:,np.array([0,2])[None,:]].transpose(2,0,1)
# (3,2,2)
A[np.array([0,1])[:,None],:,np.array([0,2])[None,:]].transpose(0,2,1)
# (2, 3, 2)

答案 1 :(得分:0)

我越接近您的规格,我就无法设计出一种能够在不知道A的情况下计算正确指数的解决方案(或者更准确地说,它的形状......) )。

import numpy as np  

def index(A, s):
    ind = []
    groups = s.split(';')
    for i, group in enumerate(groups):
        if group == ":":
            ind.append(range(A.shape[i]))
        else:
            ind.append([int(n) for n in group.split(',')])
    return np.ix_(*ind)

A = np.arange(3*3*3).reshape((3,3,3))

ind2 = index(A,"0,1;:;0,2")
print A[ind2]

较短的版本

def index2(A,s):return np.ix_(*[range(A.shape[i])if g==":"else[int(n)for n in g.split(',')]for i,g in enumerate(s.split(';'))])

ind3 = index2(A,"0,1;:;0,2")
print A[ind3]

答案 2 :(得分:0)

在使用ix_这样的限制索引案例中,可以在连续的步骤中进行索引。

A[ind1]

相同
A[i1][:,i2][:,:,i3]

,因为i2是全范围,

A[i1][...,i3]

如果您只有ind2可用

A[ind2[0].flatten()][[ind2[2].flatten()]

在更一般的情况下,您必须知道j0,j1,j2如何相互广播,但当它们由ix_生成时,关系很简单。

我可以想象分配A1 = A[i1]方便的情况,然后是涉及A1的各种行动,包括但不限于A1[...,i3]。您必须知道A1何时是一个视图,何时是一个副本。

另一个索引工具是take

A.take(i0,axis=0).take(i2,axis=2)