假设我有以下两个数组:
a = array([(1, 'L', 74.423088306605), (5, 'H', 128.05441039929008),
(2, 'L', 68.0581377353869), (0, 'H', 88.15726964130869),
(4, 'L', 97.4501582588212), (3, 'H', 92.98550136344437),
(7, 'L', 87.75945631669309), (6, 'L', 90.43196739694255),
(8, 'H', 111.13662092749307), (15, 'H', 91.44444608631304),
(10, 'L', 85.43615908319185), (11, 'L', 78.11685661303494),
(13, 'H', 108.2841293816308), (17, 'L', 74.43917911042259),
(14, 'H', 64.41057325770373), (9, 'L', 27.407214746467943),
(16, 'H', 81.50506434964355), (12, 'H', 97.79700070323196),
(19, 'L', 51.139258140713025), (18, 'H', 118.34835768605957)],
dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])
b = array([ 0, 3, 5, 8, 12, 13, 14, 15, 16, 18], dtype=int32)
我想从a
中选择id
中b
来选择的元素。也就是说,b
不是索引数组。它包含ids
个观察结果。我怎么能在numpy中做到这一点?
感谢您的帮助。
答案 0 :(得分:5)
你应该得到你想要的东西
indeces = [i for i,id in enumerate(a['id']) if id in b]
suba = a[indeces]
print(suba)
>>>array([(5, 'H', 128.05441039929008), (0, 'H', 88.15726964130869),
(3, 'H', 92.98550136344437), (8, 'H', 111.13662092749307),
(15, 'H', 91.44444608631304), (13, 'H', 108.2841293816308),
(14, 'H', 64.41057325770373), (16, 'H', 81.50506434964355),
(12, 'H', 97.79700070323196), (18, 'H', 118.34835768605957)],
dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
答案 1 :(得分:5)
以下工作比Francesco对您的样本数组的方法快几倍:
In [7]: a[np.argmax(a['id'][None, :] == b[:, None], axis=1)]
Out[7]:
array([(0, 'H', 88.15726964130869), (3, 'H', 92.98550136344437),
(5, 'H', 128.05441039929008), (8, 'H', 111.13662092749307),
(12, 'H', 97.79700070323196), (13, 'H', 108.2841293816308),
(14, 'H', 64.41057325770373), (15, 'H', 91.44444608631304),
(16, 'H', 81.50506434964355), (18, 'H', 118.34835768605957)],
dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
In [8]: %timeit a[np.argmax(a['id'][None, :] == b[:, None], axis=1)]
100000 loops, best of 3: 11.6 us per loop
In [9]: %timeit indices = [i for i,id in enumerate(a['id']) if id in b]; a[indices]
10000 loops, best of 3: 66.9 us per loop
要了解它的工作原理,请看一下:
In [10]: a['id'][None, :] == b[:, None]
Out[10]:
array([[False, False, False, True, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False],
... # several rows removed
[False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, True]], dtype=bool)
这是一个数组,其行数与b
中的元素数量相同,列数与a
中的元素数量相同。 np.argmax
然后在每一行中找到第一个True
的位置,这是b
中a['id']
的相应元素首次出现的索引。
如上所示,对于小型数组,这在性能方面优于python。但是,如果a
或b
过大,那么bool
的中间数组的大小可能会削弱性能。此外,np.argmax
必须搜索整行,它从不会提前退出循环,如果a
太长,这不是一件好事。我在使用类似方法的this question的答案中做了一些时间安排,并且仍然是适合大型数组的方法。
Francesco的方法肯定不那么简单,更容易理解,对于数组大小的样本,性能差异是无关紧要的,我必须承认。但它不会让你感觉像this ...
答案 2 :(得分:0)
sorted = numpy.sort(a)
sorted[b]
array([(0, 'H', 88.15726964130869), (3, 'H', 92.98550136344437),
(5, 'H', 128.05441039929008), (8, 'H', 111.13662092749307),
(12, 'H', 97.79700070323196), (13, 'H', 108.2841293816308),
(14, 'H', 64.41057325770373), (15, 'H', 91.44444608631304),
(16, 'H', 81.50506434964355), (18, 'H', 118.34835768605957)],
dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
只要数组中的行数与id一样多。