我通过一组概率probs
从多项分布生成绘制向量,其中每个绘制是probs
中所选项的索引:
import numpy as np
def sample_mult(K, probs):
result = np.zeros(num_draws, dtype=np.int32)
for n in xrange(K):
draws = np.random.multinomial(1, probs)
result[n] = np.where(draws == 1)[0][0]
return result
这可以加快吗?一遍又一遍地调用np.random.multinomial
似乎效率低下(np.where
也可能很慢。)
timeit
说The slowest run took 6.72 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 3: 18.9 µs per loop
答案 0 :(得分:6)
您可以使用size
选项与np.random.multinomial
一起使用随机样本行而不是仅使用默认size=1
的一行输出,然后使用.argmax(1)
来模拟np.where()[0][0]
{3}}行为。
因此,我们将有一个矢量化解决方案,如此 -
result = (np.random.multinomial(1,probs,size=K)==1).argmax(1)
答案 1 :(得分:0)
“ =”的p =参数可以做到这一点(并避免使用argmax):
result = np.random.choice(len(probs), K, p=probs)