n轴沿着2轴“拿”

时间:2017-04-29 05:28:25

标签: python numpy indexing

我有一个3D数组a数据和一个2D数组b索引。我需要使用a中的索引沿第3轴采用b的子数组。我可以使用take这样做:

a = np.arange(24).reshape((2,3,4))
b = np.array([0,2,1,3]).reshape((2,2))
np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])

我可以在没有列表理解的情况下使用一些花哨的索引吗?我担心效率,所以如果在这种情况下花式索引不是更有效,我想知道它。

编辑我尝试的第一件事是a[[0,1],:,b],但它没有提供我需要的子数组

2 个答案:

答案 0 :(得分:2)

In [317]: a
Out[317]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
In [318]: a = np.arange(24).reshape((2,3,4))
     ...: b = np.array([0,2,1,3]).reshape((2,2))
     ...: np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])
     ...: 
Out[318]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])

所以你想要0&第1块的2列,1& 3从第二个。

制作一个与c形状匹配的b,并体现这种观察

In [319]: c=np.array([[0,0],[1,1]])
In [320]: c
Out[320]: 
array([[0, 0],
       [1, 1]])
In [321]: b
Out[321]: 
array([[0, 2],
       [1, 3]])

In [322]: a[c,:,b]
Out[322]: 
array([[[ 0,  4,  8],
        [ 2,  6, 10]],

       [[13, 17, 21],
        [15, 19, 23]]])

这是正确的数字,但形状不正确。

可以使用列向量代替c

In [323]: a[np.arange(2)[:,None],:,b]  # or a[[[0],[1]],:,b]
Out[323]: 
array([[[ 0,  4,  8],
        [ 2,  6, 10]],

       [[13, 17, 21],
        [15, 19, 23]]])

至于形状,我们可以转置最后两个轴

In [324]: a[np.arange(2)[:,None],:,b].transpose(0,2,1)
Out[324]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])

这个转置是必需的,因为我们在两个索引数组之间有一个切片,它是基本索引和高级索引的混合。这是有记录的,但从来没有那么令人费解。它将切片尺寸(3)放在最后,我们必须将其转置回来。

很好的小索引拼图!

此高级/基本转置的最新问题和解释:

Indexing numpy multidimensional arrays depends on a slicing method

答案 1 :(得分:1)

这是我的第一次尝试。我会看看能不能做得更好。

#using numpy broadcasting.
np.r_[a[0][:,b[0]],a[1][:,b[1]]].reshape(2,3,2)
Out[300]: In [301]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])

第二次尝试:

#convert both a and b to a 2d array and then slice all rows and only columns determined by b.
a.reshape(6,4)[np.arange(6)[:,None],b.repeat(3,0)].reshape(2,3,2)
Out[429]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])