给定2D NumPy数组a
和存储在index
中的索引列表,必须有一种非常有效地提取列表值的方法。使用如下的for循环需要大约5 ms,这对于提取的2000个元素来说似乎非常慢:
import numpy as np
import time
# generate dummy array
a = np.arange(4000).reshape(1000, 4)
# generate dummy list of indices
r1 = np.random.randint(1000, size=2000)
r2 = np.random.randint(3, size=2000)
index = np.concatenate([[r1], [r2]]).T
start = time.time()
result = [a[i, j] for [i, j] in index]
print time.time() - start
如何提高提取速度? np.take
在这里似乎不合适,因为它会返回一个2D数组而不是一维数组。
答案 0 :(得分:2)
您可以使用advanced indexing,这基本上是指从index
数组中提取行索引和列索引,然后使用它从a
中提取值,即a[index[:,0], index[:,1]]
- < / p>
%timeit a[index[:,0], index[:,1]]
# 12.1 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit [a[i, j] for [i, j] in index]
# 2.22 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
答案 1 :(得分:2)
另一个选项是numpy.ravel_multi_index
,它可以让您避免手动编制索引。
np.ravel_multi_index(index.T, a.shape)