python 2的高性能加权随机选择?

时间:2018-03-08 14:28:04

标签: python numpy random pypy

我有以下python方法,它从由其他序列随机加权的序列“seq”中选择加权随机元素,其中包含seq中每个元素的权重:

def weighted_choice(seq, weights):
    assert len(seq) == len(weights)

    total = sum(weights)
    r = random.uniform(0, total)
    upto = 0
    for i in range(len(seq)):
        if upto + weights[i] >= r:
            return seq[i]
        upto += weights[i]
    assert False, "Shouldn't get here"

如果我使用1000个元素序列调用上述一百万次,如下所示:

seq = range(1000)
weights = []
for i in range(1000):
    weights.append(random.randint(1,100))

st=time.time()
for i in range(1000000):
    r=weighted_choice(seq, weights)
print (time.time()-st)

它在cpython 2.7中运行大约45秒,在cpython 3.6中运行70秒。 它在pypy 5.10中以2.3秒完成,这对我来说很好,遗憾的是我出于某些原因不能使用pypy。

有关如何在cpython上加速此功能的任何想法?我对其他实现(通过算法或通过外部库,如numpy)感兴趣,如果它们表现更好。

ps:python3有random.choices有权重,它运行大约23秒,这比上面的函数要好,但仍然比pypy运行慢十倍。

我用这种方式尝试了numpy:

weights=[1./1000]*1000
st=time.time()
for i in range(1000000):
    #r=weighted_choice(seq, weights)
    #r=random.choices(seq, weights)
    r=numpy.random.choice(seq, p=weights)
print (time.time()-st)

它跑了70秒。

2 个答案:

答案 0 :(得分:2)

您可以使用numpy.random.choicegenes = Gene.objects.filter(gene_name__in=self.gene_list).values('gene_name') genes_set = set(gene.gene_name for gene in genes) not_in_db = set(gene_list) - genes_set 参数是权重)。通常p函数是矢量化的,因此以近C速度运行。

实施为:

numpy

编辑:

时间:

def weighted_choice(seq, weights):
    w = np.asarray(weights)
    p = w / w.sum()  # can skip if weights always sum to 1
    return np.random.choice(seq, p=w)

答案 1 :(得分:0)

您可以使用numpy采用此方法。如果你忽略for循环,你可以通过索引你需要的位置来获得numpy的真正力量

#Untimed since you did not
seq = np.arange(1000)
weights = np.random.randint(1,100,(1000,1))


def weights_numpy(seq,weights,iterations):
    """
    :param seq: Input sequence
    :param weights: Input Weights
    :param iterations: Iterations to run
    :return: 
    """
    r = np.random.uniform(0,weights.sum(0),(1,iterations)) #create array of choices
    ar = weights.cumsum(0) # get cumulative sum
    return seq[(ar >= r).argmax(0)] #get indeces of seq that meet your condition

时间(python 3,numpy '1.14.0'

%timeit weights_numpy(seq,weights,1000000)
4.05 s ± 256 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

哪个比PyPy慢一点,但很难......