在3d阵列上的高效Numpy切片

时间:2017-12-08 16:16:50

标签: python-2.7 numpy

我正在尝试找到一种最有效的方法来为3D numpy数组进行切片。这是数据的子集,仅用于测试目的:

in_arr =np.array([[[0,1,2,5],[2,3,2,6],[0,1,3,2]],[[1,2,3,4],[3,1,0,5],[2,4,0,1]]])
indx =[[3,1,2],[2,0,1]]

我需要按照规定获得indx的值。例如,indx [0] [0]为3,所以我正在寻找in_arr [0] [0]的第3个元素,在本例中为5。

我有以下代码可以做我需要做的事情,但时间复杂度是n ^ 2,我不满意。

list_in =[]
for x in range(len(indx)):
    arr2 = []
    for y in range(len(indx[x])):
        arr2.append(in_arr[x][y][indx[x][y]])
        #print in_arr[x][y][indx[x][y]]
    list_in.append(arr2)
print list_in

我正在寻找一种非常快速有效的方法来为大型数据集执行相同的任务。

1 个答案:

答案 0 :(得分:1)

您可以使用广播的索引数组有效地完成此操作;例如:

i1 = np.arange(2)[:, np.newaxis]
i2 = np.arange(3)[np.newaxis, :]
i3 = np.array(indx)
in_arr[i1, i2, i3]
# array([[5, 3, 3],
#        [3, 3, 4]])

这里的numpy有效地匹配三个索引数组的条目,并从in_arr中提取相关条目:[:, np.newaxis][np.newaxis, :]条款的原因是它通过numpy的broadcasting规则重新整形三个数组。