如何曲线拟合多个数据集Python中有4个数据集?

时间:2019-07-07 01:09:43

标签: python curve-fitting

我一直在尝试对一个线性数据和三个指数衰减数据进行曲线拟合,并且能够分别在### ---行之有效的曲线拟合--- ###下进行拟合。然后,我尝试简化它,以便可以将所有数据组合在一起,而不必分别进行处理,但是我得到了一个奇怪的曲线拟合。如何在Python中曲线拟合4组随机数据?而且,总的来说,有没有一种方法可以简化我的代码,非常感谢!

我用驱动力以4个不同的beta,b值绘制谐波振荡器的能量。

####################################################################
##----Harmonic Oscillator Demonsionless with Driving Force-----###


dU=np.zeros(2)                  #Array of zeros
def oscil(U,t):    
    #Driving force as function:
    omega= 3
    F=np.cos(omega*t)                      
    
    dU[0]=U[1]
    dU[1]=-U[0]-b*U[1]+F
    return dU

#initial condition
ic = [0,1]
t = np.linspace(0,100,1000)



#4 different graphs with different beta, B, values; B=[0,0.1,0.5,1]
b=0
soln1 = odeint(oscil, ic, t)

b=0.1
soln2 = odeint(oscil, ic, t)

b=0.5
soln3 = odeint(oscil, ic, t)

b=1
soln4 = odeint(oscil, ic, t)

#Create 4 subplots
fig, ( (plt1,plt2), (plt3,plt4) ) = plt.subplots(2,2)
fig.subplots_adjust(wspace=0.4, hspace=0.8)

plt1.plot(t, soln1[:,0], label='position')
plt1.plot(t, soln1[:,1], label='velocity')
plt1.set_title('Undamped Harmonic Oscillator')
#plt1.set_xlim(-1,4*np.pi)
plt1.set_xlabel('Time')
plt2.plot(t, soln2[:,0], label='position')
plt2.plot(t, soln2[:,1], label='velocity')
plt2.set_title('Damped HO for B=0.1')
plt2.set_xlabel('Time')
plt3.plot(t, soln3[:,0], label='position')
plt3.plot(t, soln3[:,1], label='velocity')
plt3.set_title('Damped HO for B=0.5')
plt3.set_xlim(-1,50)
plt3.set_xlabel('Time')
plt4.plot(t, soln4[:,0], label='position')
plt4.plot(t, soln4[:,1], label='velocity')
plt4.set_title('Damped HO for B=1')
#plt4.set_xlim(-1,20)
plt4.set_xlabel('Time')
plt.title('HMO with Driving Force')
plt.legend(loc='best')
plt.show()

####################################################################
# Energy = x^2 + dxdt^2
Energy1 = soln1[:,0]**2+soln1[:,1]**2
Energy2 = soln2[:,0]**2+soln2[:,1]**2
Energy3 = soln3[:,0]**2+soln3[:,1]**2
Energy4 = soln4[:,0]**2+soln4[:,1]**2

# plot
plt.plot(t, Energy1, label='B=0')
plt.plot(t, Energy2, label='B=0.1')
plt.plot(t, Energy3, label='B=0.5')
plt.plot(t, Energy4, label='B=1')
plt.title('Energy')
plt.xlabel('t')
plt.legend()
plt.show()

###################################################################
####------Curve fitting that works--------###
x=t
y1=Energy1  
y2=Energy2 
y3=Energy3 
y4=Energy4 

def mod1(x, m, b):
    return m*x+b
def mod2(x, xshift, steepness, yshift):  ###estimated function
    return xshift*np.exp(-steepness*x) + yshift
def mod3(x, xshift, steepness, yshift):  ###estimated function
    return xshift*np.exp(-steepness*x) + yshift
def mod4(x, xshift, steepness, yshift):  ###estimated function
    return xshift*np.exp(-steepness*x) + yshift

init_vals1 = [-2.4463071e-04, 1.1080361e+00 ]  
init_vals2 = [1.11913923, 0.11376818, 0.0826447 ]  
init_vals3 = [1.21244276, 0.57442609, 0.07596321 ]  
init_vals4 = [1.04292458, 0.90770071, 0.06833576]       
                
fitParams1, fitCovariances = optimize.curve_fit(mod1, x, y1, p0=init_vals1)
fitParams2, fitCovariances = optimize.curve_fit(mod2, x, y2, p0=init_vals2)
fitParams3, fitCovariances = optimize.curve_fit(mod3, x, y3, p0=init_vals3)
fitParams4, fitCovariances = optimize.curve_fit(mod4, x, y4, p0=init_vals4)
plt.plot(t,y1, t,y2, t,y3, t,y4)
plt.plot(t,mod1(x,*fitParams1), 'k' )
plt.plot(t,mod2(x,*fitParams2), 'k' )
plt.plot(t,mod3(x,*fitParams3), 'k' )
plt.plot(t,mod4(x,*fitParams4), 'k' )
plt.title('Energy')
plt.xlabel('t')
plt.show()

####################################################################
###------Curve fitting that doesn't work--------###
x=np.concatenate((t,t,t))
y=np.concatenate((Energy2, Energy3, Energy4))

def mod(x, xshift, steepness, yshift):  
    return xshift*np.exp(-steepness*x) + yshift
    
init_vals =[1.11913923, 0.11376818, 0.0826447]  
fitParams, fitCovariances = curve_fit(mod, x, y, p0=init_vals)
plt.plot(x,y)
#plt.plot(t,mod1(x,*fitParams1), 'k' )
plt.plot(x,mod(x,*fitParams), 'k-' )
plt.plot(x,mod(x,*fitParams), 'p--' )
plt.plot(x,mod(x,*fitParams), 'p-' )
plt.title('Energy')
plt.xlabel('t')
plt.show()


code that works code that doesn't work

0 个答案:

没有答案