假设我们有一个numpy数组v
v=np.array([3, 5])
现在我们使用下面的代码找到一个新的矢量说w
v1=np.array(range(v[0]+1))
v2=np.array(range(v[1]+1))
w=np.array(list(itertools.product(v1,v2)))
所以看起来像这样,
array([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[0, 5],
[1, 0],
[1, 1],
[1, 2],
[1, 3],
[1, 4],
[1, 5],
[2, 0],
[2, 1],
[2, 2],
[2, 3],
[2, 4],
[2, 5],
[3, 0],
[3, 1],
[3, 2],
[3, 3],
[3, 4],
[3, 5]])
现在,我们需要找到对应于每对中的概率向量,知道每对中的第一个元素遵循二项分布Bin(v [0],0.1),并且每对中的第二个元素遵循二项分布Bin(v [1],0.05)。一种方法是通过这一个班轮
import scipy.stats as ss
prob_vector=np.array(list((ss.binom.pmf(i[0],v[0], 0.1) * ss.binom.pmf(i[1],v[1], 0.05)) for i in w))
输出:
array([5.64086303e-01, 1.48443764e-01, 1.56256594e-02, 8.22403125e-04,
2.16421875e-05, 2.27812500e-07, 1.88028768e-01, 4.94812547e-02,
5.20855312e-03, 2.74134375e-04, 7.21406250e-06, 7.59375000e-08,
2.08920853e-02, 5.49791719e-03, 5.78728125e-04, 3.04593750e-05,
8.01562500e-07, 8.43750000e-09, 7.73780938e-04, 2.03626563e-04,
2.14343750e-05, 1.12812500e-06, 2.96875000e-08, 3.12500000e-10])
但计算需要太多时间,特别是因为我正在迭代几个v向量!
有没有一种有效的方法来计算prob_vector?
由于
答案 0 :(得分:0)
你正在重做很多pmf调用,以及在Python端而不是numpy端做很多事情。我们可以通过计算v1和v2数组来保存这些计算,然后将它们相乘。
import numpy as np
import scipy.stats as ss
import itertools
def orig(x, y):
v = np.array([x, y])
v1 =np.array(range(v[0]+1))
v2=np.array(range(v[1]+1))
w=np.array(list(itertools.product(v1,v2)))
prob_vector=np.array(list((ss.binom.pmf(i[0],v[0], 0.1) * ss.binom.pmf(i[1],v[1], 0.05)) for i in w))
return prob_vector
def faster(x, y):
b0 = ss.binom.pmf(np.arange(x+1), x, 0.1)
b1 = ss.binom.pmf(np.arange(y+1), y, 0.05)
prob_array = b0[:, None] * b1
prob_vector = prob_array.ravel()
return prob_vector
给了我:
In [61]: %timeit orig(3, 5)
4.46 ms ± 82.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [62]: %timeit faster(3, 5)
192 µs ± 4.33 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [63]: %timeit orig(30, 50)
311 ms ± 24.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [64]: %timeit faster(30, 50)
209 µs ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [65]: (orig(30, 50) == faster(30, 50)).all()
Out[65]: True