我有两个2D numpy
数组,
import numpy as np
a = np.array([[ 1, 15, 16, 200, 10],
[ -1, 10, 17, 11, -1],
[ -1, -1, 20, -1, -1]])
g = np.array([[ 1, 12, 15, 100, 11],
[ 2, 13, 16, 200, 12],
[ 3, 14, 17, 300, 13],
[ 4, 17, 18, 400, 14],
[ 5, 20, 19, 500, 16]])
我想要做的是,对于g
的每一列,检查它是否包含a
的相应列中的任何元素。对于第一列,我想检查[1,2,3,4,5]
中是否显示任何值[1,-1,-1]
并返回True
。对于第二个,我想返回False
,因为[12,13,14,17,20]
中没有[15,10,-1]
中的元素。目前,我使用Python的列表理解来做到这一点。运行
result = [np.any(np.in1d(g[:,i], a[:, i])) for i in range(5)]
计算正确的结果,但a
有很多列时速度会变慢。是否有更“纯numpy
”方式做同样的事情?我觉得应该有一个axis
关键字可以添加到numpy.in1d
函数中,但是没有任何关键字...
答案 0 :(得分:1)
您可以按行处理,而不是按列处理输入。例如,您可以查看a
列中是否存在g
第一行的任何元素,以便您可以停止处理找到该元素的列。
idx = arange(a.shape[1])
result = empty((idx.size,), dtype=bool)
result.fill(False)
for j in range(a.shape[0]):
#delete this print in production
print "%d line, I look only at columns " % (j + 1), idx
line_pruned = take(a[j], idx)
g_pruned = take(g, idx, axis=1)
positive_idx = where((g_pruned - line_pruned) == 0)[1]
#delete this print in production
print "positive hit on the ", positive_idx, " -th columns"
put(result, positive_idx, True)
idx = setdiff1d(idx, positive_idx)
if not idx.size:
break
要了解它的工作原理,我们可以考虑不同的输入:
a = np.array([[ 0, 15, 16, 200, 10],
[ -1, 10, 17, 11, -1],
[ 1, -1, 20, -1, -1]])
g = np.array([[ 1, 12, 15, 100, 11],
[ 2, 13, 16, 200, 12],
[ 3, 14, 17, 300, 13],
[ 4, 17, 18, 400, 14],
[ 5, 20, 19, 500, 16]])
脚本的输出是:
1 line, I look only at columns [0 1 2 3 4]
positive hit on the [2 3] -th columns
2 line, I look only at columns [0 1 4]
positive hit on the [] -th columns
3 line, I look only at columns [0 1 4]
positive hit on the [0] -th columns
基本上你可以看到在循环的第二轮和第三轮中你是如何处理第二和第四列的。
此解决方案的性能实际上取决于许多因素,但如果您可能达到许多True
值,并且问题有很多行,则会更快。这当然也取决于输入,而不仅仅取决于形状。
答案 1 :(得分:1)
我会使用广播技巧,但这在很大程度上取决于阵列的大小和可用的RAM量:
M = g.reshape(g.shape+(1,)) - a.T.reshape((1,a.shape[1],a.shape[0]))
np.any(np.any(M == 0, axis=0), axis=1)
# returns:
# array([ True, False, True, True, False], dtype=bool)
使用一张纸和一支笔(以及较小的测试数组)(见下文)更容易解释,但基本上你是在g
中为每一列创建副本(a
中的每一行都有一个副本{1}})并从这些副本中减去a
中相应列中的单个元素。与原始算法类似,只是矢量化。
警告:如果任何数组g
或a
是1D,则需要强制它变为2D,使其形状至少为{ {1}}。
速度提升:
仅基于您的数组:因子~20
较大的数组:因子~80
(1,n)
附图解释广播: