Numpy.where与值列表一起使用

时间:2018-04-23 18:19:37

标签: python numpy

我有一个2d和1d阵列。我希望找到两行至少包含一次1d数组中的值,如下所示:

import numpy as np

A = np.array([[0, 3, 1],
           [9, 4, 6],
           [2, 7, 3],
           [1, 8, 9],
           [6, 2, 7],
           [4, 8, 0]])

B = np.array([0,1,2,3])

results = []

for elem in B:
    results.append(np.where(A==elem)[0])

这有效并产生以下数组:

[array([0, 5], dtype=int64),
 array([0, 3], dtype=int64),
 array([2, 4], dtype=int64),
 array([0, 2], dtype=int64)]

但这可能不是最好的处理方式。按照这个问题给出的答案(Search Numpy array with multiple values),我尝试了以下解决方案:

out1 = np.where(np.in1d(A, B))

num_arr = np.sort(B)
idx = np.searchsorted(B, A)
idx[idx==len(num_arr)] = 0 
out2 = A[A == num_arr[idx]]

但是这些给了我不正确的值:

In [36]: out1
Out[36]: (array([ 0,  1,  2,  6,  8,  9, 13, 17], dtype=int64),)

In [37]: out2
Out[37]: array([0, 3, 1, 2, 3, 1, 2, 0])

感谢您的帮助

3 个答案:

答案 0 :(得分:2)

由于您正在处理2D数组 * ,因此您可以使用广播将BA的版本进行比较。这将为您提供各种指数的分层形状。然后,您可以使用np.unravel_index反转结果并在原始数组中获取相应的索引。

In [50]: d = np.where(B[:, None] == A.ravel())[1]

In [51]: np.unravel_index(d, A.shape)
Out[51]: (array([0, 5, 0, 3, 2, 4, 0, 2]), array([0, 2, 2, 0, 0, 1, 1, 2]))                 
                       ^
               # expected result 

<子> *来自documentation:对于三维数组,这在代码行方面肯定是高效的,并且对于小数据集,它也可以在计算上有效。但是,对于大型数据集,创建大型3-d阵列可能会导致性能低下。 此外,广播是一种强大的工具,用于编写简短且通常直观的代码,可以在C中非常有效地进行计算。但是,有时广播会为特定算法使用不必要的大量内存。在这些情况下,最好在Python中编写算法的外部循环。这也可能产生更易读的代码,因为随着广播中维度的数量增加,使用广播的算法往往变得更难以解释。

答案 1 :(得分:1)

这是你想要的东西吗?

将numpy导入为np 来自itertools导入组合

A = np.array([[0, 3, 1],
           [9, 4, 6],
           [2, 7, 3],
           [1, 8, 9],
           [6, 2, 7],
           [4, 8, 0]])

B = np.array([0,1,2,3])

for i in combinations(A, 2):
    if np.all(np.isin(B, np.hstack(i))):
        print(i[0], ' ', i[1])

打印以下内容:

[0 3 1]   [2 7 3]
[0 3 1]   [6 2 7]

注意:此解决方案不要求行是连续的。如果需要,请告诉我。

答案 2 :(得分:1)

如果您需要知道 A 的每一行是否包含数组 B 的任何元素而不关心它是 B 的哪个特定元素,可以使用以下脚本:

输入:

np.isin(A,B).sum(axis=1)>0 

输出:

array([ True, False,  True,  True,  True,  True])