我想使用numpy.random.choice()
,但要确保抽奖间隔至少一个"间隔":
作为一个具体的例子,
import numpy as np
np.random.seed(123)
interval = 5
foo = np.random.choice(np.arange(1,50), 5) ## 5 random draws between array([ 1, 2, ..., 50])
print(foo)
## array([46, 3, 29, 35, 39])
我希望它们间隔至少interval+1
,即5 + 1 = 6。在上面的例子中,不满足这个条件:应该有另一个随机抽取,因为35和39被4分隔,小于6。
数组array([46, 3, 29, 15, 39])
没问题,因为所有绘制间隔至少为6。
numpy.random.choice(array, size)
在size
中抽取array
次抽奖。还有另一个功能用于检查"间距" numpy数组中的元素之间?我可以使用if / while语句编写上面的内容,但我不确定如何最有效地检查numpy数组中元素的间距。
答案 0 :(得分:3)
这是一个在绘图后插入空格的解决方案:
def spaced_choice(low, high, delta, n_samples):
draw = np.random.choice(high-low-(n_samples-1)*delta, n_samples, replace=False)
idx = np.argsort(draw)
draw[idx] += np.arange(low, low + delta*n_samples, delta)
return draw
示例运行:
spaced_choice(4, 20, 3, 4)
# array([ 5, 9, 19, 13])
spaced_choice(1, 50, 5, 5)
# array([30, 8, 1, 15, 43])
请注意,抽奖然后接受或拒绝和重绘策略可能非常昂贵。在最糟糕的情况下,对于10
个样本,重绘几乎需要半分钟,因为依赖/拒绝率非常差。 insert-the-spaces-afterwards方法没有这种问题。
两个例子的不同方法所需的时间:
low, high, delta, size = 1, 100, 5, 5
add_spaces 0.04245870 ms
redraw 0.11335560 ms
low, high, delta, size = 1, 20, 1, 10
add_spaces 0.03201030 ms
redraw 27881.01527220 ms
代码:
import numpy as np
import types
from timeit import timeit
def f_add_spaces(low, high, delta, n_samples):
draw = np.random.choice(high-low-(n_samples-1)*delta, n_samples, replace=False)
idx = np.argsort(draw)
draw[idx] += np.arange(low, low + delta*n_samples, delta)
return draw
def f_redraw(low, high, delta, n_samples):
foo = np.random.choice(np.arange(low, high), n_samples)
while any(x <= delta for x in np.diff(np.sort(foo))):
foo = np.random.choice(np.arange(low, high), n_samples)
return foo
for l, h, k, n in [(1, 100, 5, 5), (1, 20, 1, 10)]:
print(f'low, high, delta, size = {l}, {h}, {k}, {n}')
for name, func in list(globals().items()):
if not name.startswith('f_') or not isinstance(func, types.FunctionType):
continue
print("{:16s}{:16.8f} ms".format(name[2:], timeit(
'f(*args)', globals={'f':func, 'args':(l,h,k,n)}, number=10)*100))
答案 1 :(得分:1)
您可以先对数组进行排序,使所有点按升序排列,然后使用np.diff
查找连续值之间的差异。如果任何差异小于interval
,则表明尚未满足条件。即
import numpy as np
interval = 5
foo = np.random.choice(np.arange(1,50),5)
while np.any(np.diff(np.sort(foo)) <= interval):
foo = np.random.choice(np.arange(1,50),5)
print(foo)
哪个循环直到你得到一个numpy数组,其中所有值至少相差interval
。