我有一组(至少3条)曲线(xy-data)。对于每条曲线,参数E和T是恒定的但不同。我正在搜索系数a,n和m,以便在所有曲线上得到最佳拟合。
y= x/E + (a/n+1)*T^(n+1)*x^m
我尝试过curve_fit,但我不知道如何获取参数E和T. 函数f(参见curve_fit文档)。此外,我不确定我是否正确理解xdata。 Doc说:M长度序列或(k,M)形数组用于函数 预测因子。什么是预测因子? 由于ydata只有一个维度,我显然无法将多条曲线输入到例程中。
所以curve_fit可能是错误的方法,但我甚至不知道搜索正确的单词的神奇词汇。我不能成为第一个处理这个问题的人。
答案 0 :(得分:1)
执行此操作的一种方法是使用scipy.optimize.leastsq
代替(curve_fit
是leastsq
周围的便利包装器)。
将x
数据堆叠在一个维度中;同上y
数据。 3个单独数据集的长度甚至不重要;我们称他们为n1
,n2
和n3
,以便您的新x
和y
形状为(n1+n2+n3,)
。
在要优化的功能中,您可以在方便时拆分数据。它不是最好的功能,但这可能有效:
def function(x, E, T, a, n, m):
return x/E + (a/n+1)*T^(n+1)*x^m
def leastsq_function(params, *args):
a = params[0]
n = params[1]
m = params[2]
x = args[0]
y = args[1]
E = args[2]
T = args[3]
n1, n2 = args[2]
yfit = np.empty(x.shape)
yfit[:n1] = function(x[:n1], E[0], T[0], a, n, m)
yfit[n1:n2] = function(x[n1:n2], E[1], T[1], a, n, m)
yfit[n2:] = function(x[n2:], E[2], T[2], a, n, m)
return y - yfit
params0 = [a0, n0, m0]
args = (x, y, (E0, E1, E2), (T0, T1, T2), (n1, n1+n2))
result = scipy.optimize.leastsq(leastsq_function, params0, args=args)
我没有测试过这个,但这是原理。您现在正在将数据拆分为3个不同的调用 要优化的函数。
请注意,scipy.optimize.leastsq
只需要一个函数来返回您想要缩小的任何值,在这种情况下,实际y
数据与拟合函数数据之间的差异。 leastsq
中的实际重要变量是您要适合的参数,而不是x
和y
数据。后者作为额外的参数传递,连同三个独立数据集的大小(我没有使用n3,为了方便,我已经完成了与n1+n2
的一些操作;请记住{{1} n1
n2
和leastsq_function
是局部变量,而不是原始变量。
因为这是一个难以适应的功能(例如,它可能没有平滑的衍生物),所以
非常重要提供良好的起始值(params0
,因此所有...0
值)。
没有跨越数量级的数据或参数。一切都在1左右(几个数量级当然可以),越好。
答案 1 :(得分:1)
谢谢你们,我发现这很有用。如果有人想要这个问题的通用解决方案,我写了一个深受上述片段启发的方法:
import numpy as np
from scipy.optimize import leastsq
def multiple_reg(x, y, f, const, params0, **kwargs):
"""Do same non-linear regression on multiple curves
"""
def leastsq_func(params, *args):
x, y = args[:2]
const = args[2:]
yfit = []
for i in range(len(x)):
yfit = np.append(yfit, f(x[i],*const[i],*params))
return y-yfit
# turn const into 2d-array if 1d is given
const = np.asarray(const)
if len(const.shape) < 2:
const = np.atleast_2d(const).T
# ensure that y is flat and x is nested
if hasattr(y[0], "__len__"):
y = [item for sublist in y for item in sublist]
if not hasattr(x[0], "__len__"):
x = np.tile(x, (len(const), 1))
x_ = [item for sublist in x for item in sublist]
assert len(x_) == len(y)
# collect all arguments in a tuple
y = np.asarray(y)
args=[x,y] + list(const)
args=tuple(args) #doesn't work if args is a list!!
return leastsq(leastsq_func, params0, args=args, **kwargs)
此函数接受不同长度的 xs 和 ys,因为它们存储在嵌套列表中而不是 numpy ndarrays 中。对于此线程中呈现的特定情况,可以像这样使用该函数:
def fit(x,T,A,n,m):
return A/(n+1.0)*np.power(T,(n+1.0))*np.power(x,m)
# prepare dataset with some noise
params0 = [0.001, 1.01, -0.8]
Ts = [10, 50]
x = np.linspace(10, 100, 100)
y = np.empty((len(Ts), len(x)))
for i in range(len(Ts)):
y[i] = fit(x, Ts[i], *params) + np.random.uniform(0, 0.01, size=len(x))
# do regression
opt_params, _ = multiple_reg(x, y, fit, Ts, params0)
通过绘制数据和回归线来验证回归
import matplotlib.pyplot as plt
for i in range(len(Ts)):
plt.scatter(x, y[i], label=f"T={Ts[i]}")
plt.plot(x, fit(x, Ts[i], *opt_params), '--k')
plt.legend(loc='best')
plt.show()
答案 2 :(得分:0)
感谢Evert的回复。
正是我需要知道的事情!!
正如你所建议的那样,我尽可能地简化了这个功能。然而,任务是找到一组A,m,n以适合所有曲线。所以我的代码看起来像这样:
import numpy
import math
from scipy.optimize import leastsq
#+++++++++++++++++++++++++++++++++++++++++++++
def fit(x,T,A,n,m):
return A/(n+1.0)*math.pow(T,(n+1.0))*numpy.power(x,m)
#+++++++++++++++++++++++++++++++++++++++++++++
def leastsq_func(params, *args):
cc=args[0] #number of curves
incs=args[1] #number of points
x=args[2]
y=args[3]
T=args[4:]
A=params[0]
n=params[1]
m=params[2]
yfit=numpy.empty(x.shape)
for i in range(cc):
v=i*incs
b=(i+1)*incs
if b<cc:
yfit[v:b]=fit(x[v:b],T[i],A,n,m)
else:
yfit[v:]=fit(x[v:],T[i],A,n,m)
return y-yfit
#+++++++++++++++++++++++++++++++++++++++++++++
Ts =[10,100,1000,10000] #4 T-values for 4 curves
incs=10 #10 datapoints in each curve
x=["measured data"] #all 40 x-values
y=["measrued data"] #all 40 y-values
x=numpy.array(x)
y=numpy.array(y)
params0=[0.001,1.01,-0.8] #parameter guess
args=[len(Ts),incs,x,y]
for c in Ts:
args.append(c)
args=tuple(args) #doesn't work if args is a list!!
result=leastsq(leastsq_func, params0, args=args)
像钟表一样工作。
起初我将Ts放在params0列表中,然后就是 在迭代期间修改导致无意义的结果。 很明显,如果你考虑一下。之后; - )
所以,Vielen Dank! 学家