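What is the best way to count the rows of a 2d numpy array that contain all the values of another 1d numpy array? The 2d array can have more columns than the length of the 1d array.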
elements = np.arange(4).reshape((2, 2))
test_elements = [2, 3]
somefunction(elements, test_elements)
I would like the function to return 1.
elements = np.arange(15).reshape((5, 3))
# array([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14]])
test_elements = [4, 3]
somefunction(elements, test_elements)
This should also return 1.
All elements of the 1d array must be present. If only some of the elements are found in a row, that row does not count. Hence:
elements = np.arange(15).reshape((5, 3))
# array([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14]])
test_elements = [3, 4, 10]
somefunction(elements, test_elements)
should return 0.
Answer 0 (score: 0)
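Create a boolean array of the elements found, then apply any row-wise (so that several matching values in the same row are counted only once), and finally count the rows with sum: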
np.any(np.isin(elements, test), axis=1).sum()
Output:
>>> elements
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14]])
>>> test = [1, 6, 7, 4]
>>> np.any(np.isin(elements, test), axis=1).sum()
3
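Note that np.any(..., axis=1) counts rows containing *at least one* of the test values, which is why the output above is 3. For the question's requirement (every test value present), a minimal sketch that swaps in a broadcasting comparison (not part of this answer) matches every test value against every column and reduces with `all` instead. The helper names `count_rows_containing_any` and `count_rows_containing_all` are hypothetical, and the intermediate array has shape `(n_rows, n_cols, n_test)`, so this is only suitable for moderate sizes.

```python
import numpy as np

def count_rows_containing_any(arr, test):
    # Rows in which at least one value of `test` appears (the approach above).
    return int(np.any(np.isin(arr, test), axis=1).sum())

def count_rows_containing_all(arr, test):
    # Hypothetical helper: rows in which *every* value of `test` appears.
    test = np.asarray(test)
    # hits[i, j, k] is True when arr[i, j] == test[k]
    hits = arr[:, :, None] == test[None, None, :]
    # any over columns: is test[k] somewhere in row i?  all over k: are all of them?
    return int(hits.any(axis=1).all(axis=1).sum())

elements = np.arange(15).reshape((5, 3))
print(count_rows_containing_any(elements, [1, 6, 7, 4]))  # 3
print(count_rows_containing_all(elements, [1, 6, 7, 4]))  # 0
print(count_rows_containing_all(elements, [4, 3]))        # 1
print(count_rows_containing_all(elements, [3, 4, 10]))    # 0
```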
Answer 1 (score: 0)
(Edit: OK, now that I actually had more time to figure out what is going on.)
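There are two issues here, and the problem can be split into two parts: the loop over the rows, and the subset check performed on each row.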
We know that, for sufficiently large inputs, looping through the rows is faster in NumPy and slower in pure Python.
For reference, let's consider the following two approaches:
import numpy as np
# pure Python approach
def all_in_by_row_flt(arr, elems):
    return sum(1 for row in arr if all(e in row for e in elems))
# NumPy approach (based on @Mstaino's answer)
def all_in_by_row_np(arr, elems):
    def _aaa_helper(row, e=elems):
        return np.isin(e, row)
    return np.sum(np.all(np.apply_along_axis(_aaa_helper, 1, arr), 1))
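Next, consider the subset check operation: if the input is such that the check completes within a few loops, the pure Python loop is faster than NumPy. Conversely, if enough loops are required, NumPy can actually be faster.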
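On top of this, there is the loop over the rows. Because the subset check is quadratic and has different constant factors in the two implementations, there are cases where the overall operation is faster in pure Python, even though the row loop itself is faster in NumPy (since the number of rows would be large enough).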
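This is the situation I was running into in my earlier benchmarks, and it corresponds to the case where the subset check is always (or almost always) False and fails within a few loops.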
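As soon as the subset check starts requiring more loops, the pure Python approach starts to lag behind, and in the case where the subset check is actually True for most (if not all) rows, the NumPy approach is actually faster.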
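To make the "number of loops" argument concrete, a small pure-Python sketch can count how many element comparisons a short-circuiting subset check performs on a single row. The helper `count_comparisons` is hypothetical; it mirrors the nested loops of `all_in_by_row_c` / `all_in_by_row_jit` and the lazy `all(e in row ...)` of `all_in_by_row_flt`.

```python
def count_comparisons(row, elems):
    """Return (is_subset, n_comparisons) for a short-circuiting nested-loop check."""
    n = 0
    for e in elems:
        found = False
        for r in row:
            n += 1
            if e == r:
                found = True
                break  # stop scanning the row as soon as e is found
        if not found:
            return False, n  # one missing element rejects the row immediately
    return True, n

row = list(range(1000))
# The first test element is missing: one scan of the row, then stop.
print(count_comparisons(row, [-1] + list(range(100))))  # (False, 1000)
# Every test element is present: each lookup is paid for (1 + 2 + ... + 100).
print(count_comparisons(row, list(range(100))))         # (True, 5050)
```

A failing row can be rejected after very little work, while a passing row pays for every lookup, which is consistent with the pure-Python timings below growing as result grows.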
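Another major difference between the NumPy and pure Python approaches is that pure Python uses lazy evaluation while NumPy does not: NumPy actually needs to create potentially large intermediate objects, which slows down the computation.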
On top of this, NumPy iterates over the rows twice (once in sum() and once in np.apply_along_axis()), while pure Python does so only once.
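The difference in intermediate storage can be seen directly in a small sketch (sizes chosen only for illustration): `np.apply_along_axis` materializes one boolean per (row, test element) pair before `np.all` and `np.sum` run, whereas the generator used in `all_in_by_row_flt` yields one result per row and keeps nothing else.

```python
import numpy as np

arr = np.arange(12).reshape((4, 3))
elems = [3, 4]

# Eager: a full (n_rows, len(elems)) boolean array is built first.
membership = np.apply_along_axis(lambda row: np.isin(elems, row), 1, arr)
print(membership.shape)                    # (4, 2)
print(int(np.sum(np.all(membership, 1))))  # 1

# Lazy: rows are checked one at a time and no per-row result is stored.
lazy = (all(e in row for e in elems) for row in arr)
print(sum(1 for ok in lazy if ok))         # 1
```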
Other approaches using set().issubset(), such as the one from @GZ0's answer:
def all_in_by_row_set(arr, elems):
    elems = set(elems)
    # iterate over the rows of arr (iterating over a single row would count elements)
    return sum(map(elems.issubset, arr))
have different timings than explicitly writing the nested loops of the subset check, but they still suffer from the slow outer loop.
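The solution is to use Cython or Numba. The idea is to always get NumPy-like (read: C) speed, not only for sufficiently large inputs, while also using lazy evaluation and minimizing the number of loops over the rows.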
An example of the Cython approach (implemented in IPython after loading the %load_ext Cython magic) is:
%%cython --cplus -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True
# inputs must have the C long dtype (e.g. int64 on Linux/macOS)
cdef long all_in_by_row_c(long[:, :] arr, long[:] elems) nogil:
    cdef long result = 0
    cdef Py_ssize_t i, j, k
    cdef Py_ssize_t I = arr.shape[0]
    cdef Py_ssize_t J = arr.shape[1]
    cdef Py_ssize_t K = elems.shape[0]
    cdef bint is_subset, is_contained
    for i in range(I):
        is_subset = True
        for k in range(K):
            is_contained = False
            for j in range(J):
                if elems[k] == arr[i, j]:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break  # short-circuit: this row cannot be a match
        result += 1 if is_subset else 0
    return result
def all_in_by_row_cy(long[:, :] arr, long[:] elems):
    return all_in_by_row_c(arr, elems)
Similar Numba code reads:
import numba as nb
@nb.jit(nopython=True, nogil=True)
def all_in_by_row_jit(arr, elems):
    result = 0
    n_rows, n_cols = arr.shape
    for i in range(n_rows):
        is_subset = True
        for e in elems:
            is_contained = False
            for r in arr[i, :]:
                if e == r:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break  # short-circuit: this row cannot be a match
        result += 1 if is_subset else 0
    return result
Now, timing-wise, we get the following (for a relatively small number of rows):
arr.shape=(100, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 120 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit 129 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt 2.44 ms ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_set 9.98 ms ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np 13.7 ms ± 52.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
arr.shape=(100, 2000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 1.45 ms ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_jit 1.52 ms ± 4.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_flt 30.1 ms ± 452 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_set 19.8 ms ± 56.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np 18 ms ± 28.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
arr.shape=(100, 3000) elems.shape=(1000,) result=37
Func: all_in_by_row_cy 10.4 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 10.9 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 226 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30.5 ms ± 92.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np 21.9 ms ± 87.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
arr.shape=(100, 4000) elems.shape=(1000,) result=86
Func: all_in_by_row_cy 16.8 ms ± 32.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 17.7 ms ± 42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 385 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 39.5 ms ± 588 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np 25.7 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
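Now, the slowdown in the last block cannot be explained by the increase in input size along the second dimension. In fact, if the short-circuit rate is increased (for example, by changing the value range of the random arrays), the last block (same input size) gives: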
arr.shape=(100, 4000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 152 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit 173 µs ± 4.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt 556 µs ± 8.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_set 39.7 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np 31.5 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
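The benchmark input generator is not shown; as an assumption, random integer arrays can be drawn with `numpy.random.default_rng().integers`, where a narrower value range makes it more likely that a row contains every test element (fewer early exits) and a wider range makes rows fail on the first element. The helpers `make_inputs` and `count_rows` and all sizes below are hypothetical.

```python
import numpy as np

def make_inputs(n_rows, n_cols, n_elems, high, seed=0):
    # Hypothetical benchmark inputs: values drawn uniformly from [0, high).
    rng = np.random.default_rng(seed)
    arr = rng.integers(0, high, size=(n_rows, n_cols))
    elems = rng.integers(0, high, size=n_elems)
    return arr, elems

def count_rows(arr, elems):
    # Same logic as all_in_by_row_flt.
    return sum(1 for row in arr if all(e in row for e in elems))

# Narrow range: essentially every row contains all test elements.
arr, elems = make_inputs(50, 200, 3, high=5)
print(count_rows(arr, elems))  # 50 (with overwhelming probability)

# Wide range: rows are rejected early, typically on the first element.
arr, elems = make_inputs(50, 10, 3, high=10**6)
print(count_rows(arr, elems))  # 0 (with overwhelming probability)
```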
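Note that the set()-based approaches are essentially independent of the short-circuit rate (because the hash-based implementation has ~O(1) complexity for membership checks, but this comes at the cost of precomputing the hashes), and these results indicate that this may not be faster than the direct nested-loop approach.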
Finally, for larger numbers of rows:
arr.shape=(100000, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy 141 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_jit 150 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_flt 2.6 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 10.1 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 13.7 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr.shape=(100000, 2000) elems.shape=(1000,) result=34
Func: all_in_by_row_cy 1.2 s ± 753 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 1.27 s ± 7.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 24.1 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 19.5 s ± 270 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 18 s ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr.shape=(100000, 3000) elems.shape=(1000,) result=33859
Func: all_in_by_row_cy 9.79 s ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 10.3 s ± 5.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 3min 30s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30 s ± 57.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 21.9 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
arr.shape=(100000, 4000) elems.shape=(1000,) result=86376
Func: all_in_by_row_cy 17 s ± 30.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 17.9 s ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 6min 29s ± 293 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 38.9 s ± 33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np 25.7 s ± 29.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Finally, note that the Cython/Numba code could still be optimized algorithmically.
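As one possible example of such an algorithmic change (not the author's code and not benchmarked here), sorting each row once lets every test element be located by binary search, so the per-row check costs roughly O(J log J + K log J) instead of O(J·K) for J columns and K test elements. This minimal NumPy sketch still loops over rows in Python; the same idea could be ported to Cython/Numba. The name `all_in_by_row_sorted` is hypothetical.

```python
import numpy as np

def all_in_by_row_sorted(arr, elems):
    """Count rows of `arr` containing every value of `elems` (binary-search variant)."""
    elems = np.unique(elems)            # sorted, duplicates removed
    sorted_rows = np.sort(arr, axis=1)
    last = arr.shape[1] - 1
    count = 0
    for row in sorted_rows:
        # Leftmost position where each test value would be inserted into `row`.
        idx = np.searchsorted(row, elems)
        # Values above the row maximum get idx == len(row); clamp to a valid index.
        idx = np.minimum(idx, last)
        # A test value is present only if it sits at its insertion point.
        if np.all(row[idx] == elems):
            count += 1
    return count

elements = np.arange(15).reshape((5, 3))
print(all_in_by_row_sorted(elements, [4, 3]))      # 1
print(all_in_by_row_sorted(elements, [3, 4, 10]))  # 0
```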
Answer 2 (score: 0)
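There may be a more efficient solution, but if you want the rows in which "all" elements of test_elements are present, you can invert np.isin and apply it along each row, like this: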
np.apply_along_axis(lambda x: np.isin(test_elements, x), 1, elements).all(1).sum()
Answer 3 (score: 0)
The following is a slightly more efficient (but less readable) variant of @norok2's solution:
sum(map(set(test_elements).issubset, elements))
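A usage sketch on the examples from the question: iterating over a 2d array yields its rows, and NumPy integer scalars hash and compare equal to the corresponding Python ints, so `set.issubset` can consume each row directly. The wrapper name `somefunction` is the placeholder used in the question.

```python
import numpy as np

def somefunction(elements, test_elements):
    # Count rows of `elements` that contain every value in `test_elements`.
    return sum(map(set(test_elements).issubset, elements))

print(somefunction(np.arange(4).reshape((2, 2)), [2, 3]))       # 1
print(somefunction(np.arange(15).reshape((5, 3)), [4, 3]))      # 1
print(somefunction(np.arange(15).reshape((5, 3)), [3, 4, 10]))  # 0
```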