我有一个2D numpy数组如下:
import numpy as np
foo = np.array([[(i+1)*(j+1) for i in range(10)] for j in range(5)])
#array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
# [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
# [ 3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
# [ 4, 8, 12, 16, 20, 24, 28, 32, 36, 40],
# [ 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]])
我使用np.nonzero创建了一些过滤条件:
csum = np.sum(foo,axis=0)
#array([ 15, 30, 45, 60, 75, 90, 105, 120, 135, 150])
rsum = np.sum(foo,axis=1)
#array([ 55, 110, 165, 220, 275])
cfilter = np.nonzero(csum > 80)
#(array([5, 6, 7, 8, 9]),)
rfilter = np.nonzero(rsum < 165)
#(array([0, 1]),)
现在是否有一些优雅的numpy切片方法来获取rfilter中的r和cfilter中的c的所有foo [r,c]组合?即我想获得以下输出:
array([[ 6, 7, 8, 9, 10],
[12, 14, 16, 18, 20]])
注意:我知道从阵列中获取块的基本切片选择很容易,但在更高级的用例中,cfilter和rfilter中的索引不一定是彼此相邻的。
非常感谢!
答案 0 :(得分:5)
要对叉制品编制索引,请使用np.ix_
:
foo[np.ix_(*(rfilter + cfilter))]
您可以直接使用布尔索引(即不使用np.nonzero
):
foo[np.ix_(np.sum(foo, axis=1) < 165, np.sum(foo, axis=0) > 80)]
请注意,所有np.ix_
都会适当地添加轴以提供可以一起广播的索引数组:
>>> np.ix_(*(rfilter + cfilter))
(array([[0],
[1]]), array([[5, 6, 7, 8, 9]]))
答案 1 :(得分:1)
另一种方法是使用索引两次:
In [167]: foo[rsum<165][:,csum>80]
Out[167]:
array([[ 6, 7, 8, 9, 10],
[12, 14, 16, 18, 20]])
它可读且相当快:
In [168]: %timeit foo[rsum<165][:,csum>80]
100000 loops, best of 3: 9.66 us per loop
In [170]: %timeit foo[np.ix_(rsum<165, csum>80)]
100000 loops, best of 3: 16.4 us per loop
PS:创建foo
的更快捷方式是
In [31]: np.multiply.outer(range(1,6),range(1,11))
Out[31]:
array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
[ 3, 6, 9, 12, 15, 18, 21, 24, 27, 30],
[ 4, 8, 12, 16, 20, 24, 28, 32, 36, 40],
[ 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]])
In [32]: %timeit np.multiply.outer(range(1,6),range(1,11))
100000 loops, best of 3: 14.2 us per loop
In [33]: %timeit np.array([[(i+1)*(j+1) for i in range(10)] for j in range(5)])
10000 loops, best of 3: 26.6 us per loop
答案 2 :(得分:0)
你实际上并不需要非零。像(csum> 80)这样的表达式产生新的矩阵。你想要的是(csum&gt; 80)&amp;&amp; (rsum&lt; 165),但&amp;&amp;没有在矩阵上定义。但是,*是,它在布尔矩阵上也是如此。你唯一的问题是你的csum和rsum数组不是正确的形状。但如果你正确堆叠它们,它们就可以播出。
csum = np.hstack (sum (foo, axis=0))
rsum = np.vstack (sum (foo, axis=1))
print foo[(csum > 80) * (hsum < 165)]
唯一的缺点是,它会生成您在一维数组中要求的单元格的值。您将需要重塑()它以获得您要求的格式。