Python拟合:优化循环

时间:2016-11-14 14:49:24

标签: python loops optimization model-fitting

我正在尝试用多个高斯函数拟合给定的数据。这部分程序在拟合到第500个模型时已占用大约3 GB的内存,而我总共需要拟合约2000个模型。以下是我的程序的简化版本,其中使用随机生成的数据——这些数据本身无法得到很好的拟合,但足以说明耗时问题:

import sys
sys.setrecursionlimit(5000)
import matplotlib.pyplot as plt
import numpy as np
import random
from random import uniform
# Synthetic data set: 1000 random x positions and 1000 random y values.
x = [uniform(2200., 3100.) for _ in range(1000)]
y = [uniform(1., 1000.) for _ in range(1000)]

import sherpa.ui as ui
import numpy as np
# Load the (x, y) arrays into the Sherpa session, plot them, and fit a
# power-law model to them with the least-squares statistic.
ui.load_arrays(1,x,y) # 1 is the id given to this (first) data set
d1=ui.get_data()
d1.staterror=0.002*d1.y # define error on y just for plotting purpose, not required for fit
ui.plot_data()
ui.set_stat("leastsq") # least-squares method for the fit
ui.set_model(ui.powlaw1d.pow1) # fit a power law; pow1 is the shortcut name
# ui.show_all() will show you all the parameters for the model
pow1.ref=2500 # set the model's "ref" parameter before fitting
ui.fit()
# fitting  templates
# Template data: x2 supplies the Gaussian positions, y2 the starting amplitudes.
x2 = [uniform(2200., 3100.) for _ in range(1000)]
y2 = [uniform(1., 1000.) for _ in range(1000)]

# State used by the incremental fit.
model1 = "pow1"  # source expression string; Gaussians are appended to it
sign = "+"       # operator used to join components in the expression
sigma = 45.      # value assigned to every Gaussian's fwhm parameter
g_pos = x2       # same list object as x2
g_ampl = []      # fitted amplitudes are collected here



# Freeze the power law so that only Gaussian amplitudes are fitted below.
ui.freeze(model1)

# Add one Gaussian per template point, fit its amplitude, then freeze it
# before moving on to the next one.
# NOTE(review): the loop starts at 1, so x2[0] / y2[0] are never used and
# g_ampl[k] belongs to g_pos[k + 1] -- confirm this offset is intended.
for n in range(1, 1000):  # range() excludes the upper limit
    name = "g{}".format(n)

    # Create the component exactly once and keep the handle.  The original
    # called create_model_component a second time just to obtain this
    # reference, which is redundant at best.
    # NOTE(review): depending on the Sherpa version the second call may also
    # replace the already-configured component -- confirm against the docs.
    g = ui.create_model_component("gauss1d", name)
    ui.set_par(name + ".pos", x2[n], frozen=True)
    ui.set_par(name + ".ampl", y2[n])
    ui.set_par(name + ".fwhm", sigma, frozen=True)

    model1 = model1 + sign + name

    if y2[n] == 0.:
        g_ampl.append(0.)  # record a zero amplitude; no fit is run
    else:
        ui.set_source(model1)  # replace the source with the extended model
        # The fit is repeated three times, as in the original script.
        ui.fit()
        ui.fit()
        ui.fit()
        g_ampl.append(g.ampl.val)

    # Every earlier component is already frozen, so freezing the new
    # Gaussian is enough; this avoids re-evaluating the whole (growing)
    # expression string on every pass.
    ui.freeze(g)

我找不到优化这部分代码的方法,使其更高效、耗时更少。如果有任何能让它运行得更快的建议,我将不胜感激。

1 个答案:

答案 0 :(得分:0)

代码的问题在于它似乎不必要地存储了您想要拟合的所有数据。更好的解决方案是只存储拟合的结果。我不太了解 Sherpa,下面是一个使用 scipy.optimize 的解决方案:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import glob, os
# Switch matplotlib to interactive mode before any figure is drawn.
plt.ion()

def gaussian_func(x, a, x0, sigma, c):
    """Evaluate a Gaussian with a constant offset.

    Computes ``a * exp(-(x - x0)**2 / (2 * sigma**2)) + c``.

    Args:
        x: Point(s) at which to evaluate; a scalar or a NumPy array.
        a: Multiplier of the exponential term (peak height above ``c``).
        x0: Centre of the Gaussian.
        sigma: Width parameter appearing in the exponent.
        c: Constant added to the result.

    Returns:
        The function value, with the same shape as ``x``.
    """
    squared_offset = (x - x0) ** 2
    twice_variance = 2 * sigma ** 2
    return a * np.exp(-squared_offset / twice_variance) + c

# Write two demo files of noisy Gaussian data (centres at 1.3 and 2.3).
# Demonstration only -- do not use this in the actual program.
xdata = np.linspace(0, 4, 50)
for _demo_name, _demo_center in (('example.dat', 1.3), ('example2.dat', 2.3)):
    _clean = gaussian_func(xdata, 2.5, _demo_center, 0.5, 0)
    ydata = _clean + 0.2 * np.random.normal(size=len(xdata))
    # One row per point: column 0 is x, column 1 is y.
    np.savetxt(_demo_name, np.array([xdata, ydata]).T)

# Build the list of data files to fit.
# This example simply takes every file with a .dat extension in the current
# directory.  glob returns names in arbitrary order, so sort them to make the
# order of `results` reproducible between runs.
filelist = sorted(glob.glob("*.dat"))
print(filelist)

# One dict per file; only the fit outcome is kept, not the data itself.
results = []

for filename in filelist:
    data = np.loadtxt(filename)
    xdata = data[:, 0]
    ydata = data[:, 1]
    # Constant error of 0.2 on every point.  If the file carries per-point
    # errors in a third column, use `sigma = data[:, 2]` instead.
    sigma = np.full(ydata.shape, 0.2)
    initial_guess = [2, 1, 1, 0]  # a, x0, sigma, c for gaussian_func
    popt, pcov = curve_fit(gaussian_func, xdata, ydata,
                           p0=initial_guess, sigma=sigma)
    results.append({"filename": filename,
                    "parameters": popt,
                    "covariance matrix": pcov})

    # Plot the data and the fitted curve (figure 1 is reused for each file).
    # Should be commented out for the large dataset.
    plt.figure(1)
    plt.clf()
    plt.errorbar(xdata, ydata, sigma, fmt='ko')
    # Span the actual data range rather than the demo's hard-coded 0..4.
    xplot = np.linspace(xdata.min(), xdata.max(), 100)
    plt.plot(xplot, gaussian_func(xplot, *popt), 'r', linewidth=2)
    plt.draw()