获取numpy中多维切片沿特定轴的索引

时间:2017-08-18 20:40:42

标签: python arrays numpy

我有这样的代码:

import numpy as np    
b = np.random.choice([0, 1], size=(12, 10, 2), p=[0.5, 0.5]) > 0.5    
a = np.ones((12, 10, 2, 6, 4))

a = a[b]    
print(a.shape)

我想知道每个选择来自轴1(上面的10)的哪个位置,例如,

a [0,0,0] = 0(来自b [:,0,:])

a [0,0,1] = 3(来自b [:,3,:])

a [6,3,1] = 1(来自b [:,1,:])

我该怎么做?

这是一个没有随机选择的简化版本:

import numpy as np

b = np.array([[0, 1], [1, 1]]) > 0.5
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

a = a[b] #gives [[3,4],[5,6],[7,8]]

# Desired result: 1,0,1 as each element came from that index of axis 1 of b

[1, 0, 1] # The index along the axis of the last dim of b for each selection in a

1 个答案:

答案 0 :(得分:1)

  

我想知道每个选择来自轴1(上面的10)的哪个位置,......

当您执行a = a[b]时,a的新元素将与您的随机True数组中的b值相关。为此,您可以使用numpy.where()上的b method来了解哪些内容,例如:

import numpy as np

b = np.random.choice([0, 1], size=(12, 10, 2), p=[0.5, 0.5]) > 0.5 #random choice
a = np.ones((12, 10, 2, 6, 4))

a = a[b] #obtain those randomly selected items
print(a.shape)

indexes = np.where(b==True)
print(indexes[1]) #the axis 1 you desire

请注意,如果您希望获得其他 b轴(例如轴i),您应该像indexes[i]那样得到它。另请注意,每次都会给出不同的值,因为它是随机的。

但是,使用更简单的示例对其进行测试,我们也会获得所需的[1,0,1]

import numpy as np
b = np.array([[0, 1], [1, 1]]) > 0.5
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

a = a[b] #gives result [[3,4], [5,6], [7,8]], so they are 1,0,1
print(a.shape) #gives (3, 2)

indexes = np.where(b==True)
print(indexes[1]) #gives us the desired 1,0,1