如何确定用于优化的初始值

时间:2019-04-28 08:40:58

标签: python optimization scipy minimization

因此,问题在于我们有一堆直方图数据,并且需要使用三个高斯分布和一个指数函数来拟合数据。我当前的代码包含教师给定的适合的初始条件:p_init。但是,这种特定的初始条件仅适用于约4/5的测试数据。如何改进当前代码以适应所有不同的数据集?

我想尝试使用随机数来拟合直方图数据,直到成功为止,但是效果还不好。

import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize as opt

def hist_fit(hist):
    ret = np.array([100.,1.,0.05])
    xmin, xmax, xbinwidth = 0., 2., 0.02
    vx = np.linspace(xmin+xbinwidth/2,xmax-xbinwidth/2,100)
    vy = hist
    vyerr = vy**0.5

    def model(x, norm, mean, sigma, norm2, norm3, c0, c1):
        exponential = c1*np.exp(-x/c0)
        gaussian1 = norm*xbinwidth/(2.*np.pi)**0.5/sigma * \
                    np.exp(-0.5*((x-mean)/sigma)**2)
        gaussian2 = norm2*xbinwidth/(2.*np.pi)**0.5/(0.4) * \
                    np.exp(-0.5*((x-0.1)/0.4)**2)
        gaussian3 = norm3*xbinwidth/(2.*np.pi)**0.5/(0.2) * \
                    np.exp(-0.5*((x-0.7)/0.2)**2)
        return exponential + gaussian1 + gaussian2 + gaussian3

    def fcn(p):    
        expt = model(vx,p[0],p[1],p[2],p[3],p[4],p[5],p[6])
        delta = (vy-expt)/vyerr
        return (delta**2).sum()

    p_init = np.array([10**3.,1.,0.07,1000.,100.,2., 100.])    
    r = opt.minimize(fcn,p_init)
    ret[0] = r.x[0]
    ret[1] = r.x[1]
    ret[2] = r.x[2]
    return ret

if __name__ == '__main__':
    data = np.load('hist_data.npy')
    idx = np.random.randint(50)
    ret = hist_fit(data[idx])

    print('Signal main Gaussian parameters:')
    print('Area: %.1f' % ret[0])
    print('Mean: %.4f' % ret[1])
    print('Sigma: %.4f' % ret[2])
    print('chi^2/ndf = %.2f' % (fcn(r.x)/(len(vy)-len(r.x))))

    xmin, xmax, xbinwidth = 0., 2., 0.02
    vx = np.linspace(xmin+xbinwidth/2,xmax-xbinwidth/2,100)
    vy = data[5]
    vyerr = vy**0.5

    fig = plt.figure(figsize=(6,6), dpi=80)
    plt.errorbar(vx, vy, vyerr, c='blue', fmt = '.')
    N,mu,sigma = ret[0],ret[1],ret[2]
    cx = np.linspace(xmin,xmax,500)
    cy = N*xbinwidth/(2.*np.pi)**0.5/sigma * np.exp(-0.5*((cx-mu)/sigma)**2)
    plt.plot(cx, cy, c='red',lw=2)
    plt.plot([xmin, xmax],[0.,0.],c='gray',lw=2)
    plt.grid()
    plt.show()

该代码有望用于不同的直方图数据集。

0 个答案:

没有答案