使用numpy.where()或类似方法从矩阵中的行获取特定值

时间:2017-10-30 20:58:37

标签: python arrays numpy

作为一名MWE,我有一个2d numpy数组:

import numpy as np
n1 = np.array([[6,7,8,1], [5,2,4,8], [3,4,2,1], [8,7,2,10]])
n1
array([[ 6,  7,  8,  1],
       [ 5,  2,  4,  8],
       [ 3,  4,  2,  1],
       [ 8,  7,  2, 10]])

我想得到第一行出现“6”和“8”的索引;其中'2'和'8'出现在第二行;其中'3'和'4'出现在第三行,其中'7'和'2'出现在最后一行。也就是说,我有一个numpy数组列表:

list1 = [np.array([6, 8]), np.array([2, 8]), np.array([3, 4]), np.array([7, 2])]

我希望[n1[i, np.where(ar1)] for i in range(len(list1))]或类似的东西返回一个新列表,其中包含list1中值的列:

returned_list = [np.array([0, 2]), np.array([1, 3]), np.array([0, 1]), np.array([1, 2])]

显然我已经尝试[n1[i, np.where(ar1)] for i in range(len(list1))]了。有什么想法吗?

1 个答案:

答案 0 :(得分:2)

这是使用NumPy broadcasting进行数组输出的那个 -

In [117]: arr_list1 = np.array(list1)

In [118]: mask = (n1[:,:,None] == arr_list1[:,None,:]).any(2)

In [119]: np.where(mask)[1].reshape(-1,2)
Out[119]: 
array([[0, 2],
       [1, 3],
       [0, 1],
       [1, 2]])

<强>解释

  • 基本上,我们通过引入长度为n1arr_list1的单身dims / dims,将3D1扩展到None/np.newaxis数组,例如当相互比较时,将导致与这两个数组中的最后一个轴进行元素比较,作为完整的3D数组。

  • 然后我们在最后一个轴上查找ANY匹配,其长度为2,对应于每行arr_list1中的两个元素。这给了我们一个2D数组。最后,我们需要匹配的行索引,因此np.where()[1]

逐步进行仔细研究 -

1)输入:

In [124]: n1
Out[124]: 
array([[ 6,  7,  8,  1],
       [ 5,  2,  4,  8],
       [ 3,  4,  2,  1],
       [ 8,  7,  2, 10]])

In [125]: arr_list1
Out[125]: 
array([[6, 8],
       [2, 8],
       [3, 4],
       [7, 2]])

2)比较:

In [126]: (n1[:,:,None] == arr_list1[:,None,:])
Out[126]: 
array([[[ True, False],
        [False, False],
        [False,  True],
        [False, False]],

       [[False, False],
        [ True, False],
        [False, False],
        [False,  True]],

       [[ True, False],
        [False,  True],
        [False, False],
        [False, False]],

       [[False, False],
        [ True, False],
        [False,  True],
        [False, False]]], dtype=bool)

3)ANY减少:

In [127]: (n1[:,:,None] == arr_list1[:,None,:]).any(2)
Out[127]: 
array([[ True, False,  True, False],
       [False,  True, False,  True],
       [ True,  True, False, False],
       [False,  True,  True, False]], dtype=bool)

In [128]: mask = _

4)最后得到匹配的行索引:

In [130]: np.where(mask)[1].reshape(-1,2)
Out[130]: 
array([[0, 2],
       [1, 3],
       [0, 1],
       [1, 2]])