优化生成变量的拒绝方法

时间:2019-04-20 12:06:34

标签: python optimization random plot probability

我对生成连续随机变量的拒绝方法的优化存在问题。我有一个密度:f(x) = 3/2 (1-x^2)。这是我的代码:

import random
import matplotlib.pyplot as plt
import numpy  as np
import time
import scipy.stats as ss

a=0   # xmin
b=1   # xmax

m=3/2 # ymax
variables = [] #list for variables

def f(x):
    return 3/2 * (1 - x**2)  #probability density function

reject = 0   # number of rejections
start = time.time()
while len(variables) < 100000:  #I want to generate 100 000 variables
    u1 = random.uniform(a,b)
    u2 = random.uniform(0,m)

    if u2 <= f(u1):
        variables.append(u1)
    else:
        reject +=1
end = time.time()

print("Time: ", end-start)
print("Rejection: ", reject)
x = np.linspace(a,b,1000)
plt.hist(variables,50, density=1)
plt.plot(x, f(x))
plt.show()

ss.probplot(variables, plot=plt)
plt.show()

我的第一个问题:我的概率图设计正确吗? 第二,标题中的内容。如何优化该方法?我想获得一些建议以优化代码。现在,该代码大约需要0.5秒,并且大约有5万次拒绝。是否可以减少拒绝的时间和数量?如果需要,我可以使用另一种生成变量的方法进行优化。

3 个答案:

答案 0 :(得分:1)

  

我的第一个问题:我的概率图设计正确吗?

不。与默认的正态分布比较。您必须将函数f(x)打包到stats.rv_continuous派生的类中,使其成为_pdf方法,然后将其传递给probplot

  

第二个是标题中的内容。如何优化该方法?是否可以减少拒绝的时间和数量?

当然,您掌握了NumPy向量功能的强大功能。永远不要编写显式循环-vectoriz,vectorize和vectorize!

请看下面的修改代码,而不是单个循环,所有操作都是通过NumPy向量完成的。我的计算机上的时间从0.19减少到0.003,用于处理100000个样本(至强,Win10 x64,Anaconda Python 3.7)。

import numpy as np
import scipy.stats as ss
import matplotlib.pyplot as plt
import time

a = 0.  # xmin
b = 1.  # xmax

m = 3.0/2.0 # ymax

def f(x):
    return 1.5 * (1.0 - x*x)  # probability density function

start  = time.time()

N = 100000
u1 = np.random.uniform(a, b, N)
u2 = np.random.uniform(0.0, m, N)

negs = np.empty(N)
negs.fill(-1)
variables = np.where(u2 <= f(u1), u1, negs) # accepted samples are positive or 0, rejected are -1

end = time.time()

accept = np.extract(variables>=0.0, variables)
reject = N - len(accept)

print("Time: ", end-start)
print("Rejection: ", reject)

x = np.linspace(a, b, 1000)
plt.hist(accept, 50, density=True)
plt.plot(x, f(x))
plt.show()

ss.probplot(accept, plot=plt) # against normal distribution
plt.show()

关于减少拒绝的数量,您可以使用逆方法以0个拒绝进行采样,因为它是三次方程式,因此操作起来很容易

更新

以下是用于probplot的代码:

class my_pdf(ss.rv_continuous):
    def _pdf(self, x):
        return 1.5 * (1.0 - x*x)

ss.probplot(accept, dist=my_pdf(a=a, b=b, name='my_pdf'), plot=plt)

您应该得到类似的东西

enter image description here

答案 1 :(得分:1)

关于第一个问题,scipy.stats.probplot将样本与正态分布的分位数进行比较。如果您希望它与f(x)分布的分位数进行比较,请检查dist的{​​{1}}参数。

就使此采样过程更快而言,避免循环通常是可行的方法。用下面的代码替换probplotstart = ...之间的代码,对我来说提高了20倍以上。

end = ...

请注意,每次运行时,这将为您大约 100000个接受的样本。您可以稍微提高n_before_accept_reject = 150000 u1 = np.random.uniform(a, b, size=n_before_accept_reject) u2 = np.random.uniform(0, m, size=n_before_accept_reject) variables = u1[u2 <= f(u1)] reject = n_before_accept_reject - len(variables) 的值,以有效地保证n_before_accept_reject始终具有> 100000个可接受的值,然后在必要时限制变量的大小以恰好返回100000。

答案 2 :(得分:1)

其他人谈到了概率图,我将讨论拒绝算法的效率。

接受/拒绝方案基于m(x),即“主函数”。主化函数应具有两个属性:1)m(x)≥f(x)∀x; 2)m(x),按比例缩放为分布时,应易于从中生成值。 您使用了常数函数m = 3/2,该函数既满足两个要求,又没有非常紧密地限制f(x)。从零到一积分,面积为3/2。作为有效密度函数的f(x)的面积为1。因此,∫f(x))/∫m(x))= 1 /(3/2)= 2/3。换句话说,您从主函数生成的值的2/3被接受,而您拒绝了1/3的尝试。

您需要一个m(x)来为f(x)提供更紧密的界限。我选择了一条在x = 1/2处与f(x)相切的线。通过一点微积分就可以得出斜率,我得出了m(x) = 15/8 - 3x/2

Plot of m(x) and f(x)

m(x)的此选择的面积为9/8,因此仅将拒绝1/9的值。再根据此m(x)将x的逆变换生成器进行微积分,得到x = (5 - sqrt(25 - 24U)) / 4,其中U是均匀(0,1)随机变量。

这是基于原始版本的实现。我将拒绝方案包装在一个函数中,并使用列表理解而不是附加到列表来创建值。如您所见,如果您运行此程序,它会比原始版本产生更少的拒绝。

import random
import matplotlib.pyplot as plt
import numpy  as np
import time
import math
import scipy.stats as ss

a = 0   # xmin
b = 1   # xmax

reject = 0   # number of rejections

def f(x):
    return 3.0 / 2.0 * (1.0 - x**2)  #probability density function

def m(x):
    return 1.875 - 1.5 * x

def generate_x():
    global reject
    while True:
        x = (5.0 - math.sqrt(25.0 - random.uniform(0.0, 24.0))) / 4.0
        u = random.uniform(0, m(x))
        if u <= f(x):
            return x 
        reject += 1    

start = time.time()
variables = [generate_x() for _ in range(100000)]
end = time.time()

print("Time: ", end-start)
print("Rejection: ", reject)
x = np.linspace(a,b,1000)
plt.hist(variables,50, density=1)
plt.plot(x, f(x))
plt.show()