如何有效地采样numpy?

时间:2017-03-23 06:00:54

标签: numpy scipy

我在numpy中有一个3D布尔数组。我想从具有True值(如果有)的那些中选择一个随机元素。挑选任何True元素的概率应该相同。我需要所选元素的坐标。

有效但不是特别快的方法:

  • 随机挑选元素并检查它们是否为真

  • 形成所有True元素的索引列表(例如使用numpy.nonzero),然后从中随机选择

该数组通常是一个大小为256x256x256的多维数据集。真实元素约占总数的1%到10%。

1 个答案:

答案 0 :(得分:2)

我不确定您的方法是否可以以主要方式改进。无论如何,这里有一个import numpy as np from timeit import timeit def pick_true(data, n): nz = np.flatnonzero(data) return np.unravel_index(np.random.choice(nz, n), data.shape) def pick_true_2(data, n, p): pick = np.random.randint(0, data.size, (int(round(n/p)),)) return np.unravel_index(pick[data.ravel()[pick]], data.shape) data = np.random.random((256,256,256)) < 0.01 print(pick_true(data, 10), data[pick_true(data, 10)]) print('indexing small {:6.4f} secs'.format(timeit(lambda: pick_true(data, 100), number=10)/10)) print('indexing large {:6.4f} secs'.format(timeit(lambda: pick_true(data, 10000), number=10)/10)) print(pick_true_2(data, 10, 0.01), data[pick_true_2(data, 10, 0.01)]) print('non-indexing small {:6.4f} secs'.format(timeit(lambda: pick_true_2(data, 100, 0.01), number=10)/10)) print('non-indexing large {:6.4f} secs'.format(timeit(lambda: pick_true_2(data, 10000, 0.01), number=10)/10)) 解决方案和一个用于比较的反复试验:

(array([  8, 164, 247, 160, 154, 147, 146,  73,  89,   1]),
array([ 89,   0,   7, 217,  46,  45, 139, 205, 163,  92]),
array([ 70, 129,  92,   7,  14, 155, 148,  51, 146, 176]))
[ True  True  True  True  True  True  True  True  True  True]
indexing small 0.0072 secs
indexing large 0.0090 secs
(array([ 29, 113,   7,  18, 159, 203,  97,  45]),
array([227, 251,   8, 137,  61, 226, 170,  17]),
array([249,  28, 160,  99,  99, 191, 174, 234]))     
[ True  True  True  True  True  True  True  True]
non-indexing small 0.0002 secs
non-indexing large 0.0206 secs

示例输出:

data