我有两个2D数组,一个数字和一个布尔值:
x =
array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
[ 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[ 4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
[ 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
[ 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.],
[ 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.],
[ 9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])
idx =
array([[False, False, False, False, False, False, False, False, False, False],
[False, True, True, True, True, True, False, False, False, False],
[False, True, True, True, True, True, False, False, False, False],
[False, True, True, True, True, True, False, False, False, False],
[False, False, False, True, True, True, True, False, False, False],
[False, False, False, False, True, True, True, False, False, False],
[False, False, False, False, False, False, True, False, False, False],
[False, False, False, False, False, False, False, True, False, False],
[False, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False]], dtype=bool)
当我索引数组时,它返回一维数组:
x[idx]
array([ 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 3., 3., 3.,
3., 3., 4., 4., 4., 4., 5., 5., 5., 6., 7.])
如何索引数组并返回具有预期输出的2D数组:
x[idx]
array([[ 1., 1., 1., 1., 1.],
[ 2., 2., 2., 2., 2.],
[ 3., 3., 3., 3., 3.],
[ 4., 4., 4., 4.],
[ 5., 5., 5.],
[ 6.],
[ 7.]])
答案 0 :(得分:3)
您的命令返回一维数组,因为如果没有(a)破坏通常需要的列结构,则无法实现。例如,您请求的输出中的7
最初属于第7列,现在它位于第0列; (b)numpy
不支持在同一维度上支持不同大小的高维数组。我的意思是numpy不能有一个前三行长度为5的数组,长度为4的第四行等等 - 所有行(相同的维度)需要具有相同的长度。
我认为你可以期待的最好的结果是数组(而不是2D数组)。这就是我构建它的方式,尽管可能有更好的方法我不知道:
In [9]: from itertools import izip
In [11]: array([r[ridx] for r, ridx in izip(x, idx) if ridx.sum() > 0])
Out[11]:
array([array([ 1., 1., 1., 1., 1.]), array([ 2., 2., 2., 2., 2.]),
array([ 3., 3., 3., 3., 3.]), array([ 4., 4., 4., 4.]),
array([ 5., 5., 5.]), array([ 6.]), array([ 7.])], dtype=object)
答案 1 :(得分:0)
编辑:这会创建一个列表数组
np.array([val[idx[i]].tolist() for i,val in enumerate(x) if len(val[idx[i]].tolist()) > 0])
array([[1.0, 1.0, 1.0, 1.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0],
[3.0, 3.0, 3.0, 3.0, 3.0],
[4.0, 4.0, 4.0, 4.0],
[5.0, 5.0, 5.0],
[6.0],
[7.0]], dtype=object)