使用Scipy curve_fit在Python中更改参数数量来创建拟合函数的更好方法

时间:2016-05-26 14:38:39

标签: python scipy curve-fitting

我正在寻找一种更好的方法来使用scipy的curve_fit()

我目前正在使用它来拟合参数和矢量x的线性组合。

例如,这是一个拟合函数,我尝试使用5个参数进行拟合,m0-m4:

def degFour(x, m0, m1, m2, m3, m4):
    return x[0]*m0 + x[1]*m1 + x[2]*m2 + x[3]*m3 + x[4]*m4

我使用相同的模式将更多这些内容添加到degTen。它确实有用。

我的x矢量:

[[ 1.          1.          1.          1.          1.        ]
 [ 1.          0.99990931  0.99963727  0.99918392  0.99854935]
 [ 1.          0.94872591  0.80016169  0.56954235  0.28051747]
 [ 1.          0.84717487  0.43541052 -0.10943716 -0.62083535]
 [ 1.          0.77991807  0.21654439 -0.44214431 -0.90621706]
 [ 1.          0.73162055  0.07053725 -0.62840754 -0.99004899]
 [ 1.          0.68866877 -0.05147065 -0.75956123 -0.99470154]
 [ 1.          0.64892616 -0.15778967 -0.85371386 -0.95020484]
 [ 1.          0.6114128  -0.25234877 -0.91999134 -0.8726402 ]
 [ 1.          0.57600247 -0.33644232 -0.96358568 -0.77361313]
 [ 1.          0.54225052 -0.41192874 -0.98898767 -0.66062942]
 [ 1.          0.29541145 -0.82546415 -0.78311458  0.36278212]
 [ 1.          0.09546594 -0.98177251 -0.28291761  0.92775452]
 [ 1.         -0.07539697 -0.9886306   0.22447646  0.95478091]
 [ 1.         -0.22050008 -0.90275943  0.61861713  0.62994918]
 [ 1.         -0.33964821 -0.76927818  0.86221613  0.18357784]
 [ 1.         -0.54483185 -0.40631651  0.9875802  -0.66981378]
 [ 1.         -0.71937092  0.03498904  0.66903073 -0.99755153]
 [ 1.         -1.          1.         -1.          1.        ]]

我的数据:

[  3.50032   3.5007    3.6328    3.94564   4.12814   4.2651    4.39586
   4.51982   4.64394   4.76738   4.88654   5.90314   6.93304   7.99074
   9.04278  10.02426  12.01392  14.0592   18.1689 ]

使用curve_fit(degFour,xdata.T,ydata),我得到了正确的系数:

[ 9.14562709 -7.05004692  1.66932215 -0.27868686  0.02097462]

我根据程度重新创建x数据,因此我将始终传递具有正确形状的数据。

我尝试了一个关于变量输入参数的fbstj's answer版本。

我用过这个:

def vararg(x, *args):
    return sum(a * x[i] for i, a in enumerate(args))

结束了这个:

Traceback (most recent call last):
  File "D:/Libraries/Desktop/PScratch2/vararg.py", line 18, in <module>
    print(curve_fit(vararg, deg4kary.T, deg4ydata))
  File "C:\Python35\lib\site-packages\scipy\optimize\minpack.py", line 606, in curve_fit
    raise ValueError("Unable to determine number of fit parameters.")
ValueError: Unable to determine number of fit parameters.

从跟踪中可以看出,我刚刚传递了函数本身。我被卡住了。

1 个答案:

答案 0 :(得分:2)

您正在为数据拟合多元线性模型。这可以表示为x向量与形状(npoints, nparams)和单个(nparams,)系数向量之间的点积,例如m

def linear(x, m):
    return x.dot(m)

x = np.random.randn(100, 5)
m = np.random.randn(5)

y = linear(x, m)

确实没有必要使用curve_fit来获取m系数 - 使用np.linalg.lstsq来解决线性系统更简单,更有效率这样:

m_hat, residuals, rank, singular_vals = np.linalg.lstsq(x, y)

此处,m_hat将是(n_params,)向量,其中包含m0m1m2等的最小二乘估计值。这适用于任何系数数。