scipy curve_fit即使提供了很好的猜测也根本无法正确拟合?

时间:2019-11-11 13:31:36

标签: python numpy scipy curve-fitting gaussian

我正在尝试用指数修正高斯(EMG)函数拟合一组数据。数据列在下面代码的开头。

我有以下代码:

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import erfc

bins = [-46.82455, -46.41738, -46.01021, -45.60304, -45.19587, -44.7887,  -44.38153,
 -43.97436, -43.56719, -43.16002, -42.75285, -42.34568, -41.93851, -41.53134,
 -41.12417, -40.717,   -40.30983, -39.90266, -39.49549, -39.08832, -38.68115,
 -38.27398, -37.86681, -37.45964, -37.05247, -36.6453,  -36.23813, -35.83096,
 -35.42379, -35.01662, -34.60945, -34.20228, -33.79511, -33.38794, -32.98077,
 -32.5736,  -32.16643, -31.75926, -31.35209, -30.94492, -30.53775, -30.13058,
 -29.72341, -29.31624, -28.90907, -28.5019,  -28.09473, -27.68756, -27.28039,
 -26.87322, -26.46605, -26.05888, -25.65171, -25.24454, -24.83737, -24.4302,
 -24.02303, -23.61586, -23.20869, -22.80152, -22.39435, -21.98718, -21.58001,
 -21.17284, -20.76567, -20.3585,  -19.95133, -19.54416, -19.13699, -18.72982,
 -18.32265, -17.91548, -17.50831, -17.10114, -16.69397, -16.2868,  -15.87963,
 -15.47246, -15.06529, -14.65812, -14.25095, -13.84378, -13.43661, -13.02944,
 -12.62227, -12.2151,  -11.80793, -11.40076, -10.99359, -10.58642, -10.17925,
  -9.77208,  -9.36491,  -8.95774,  -8.55057,  -8.1434,   -7.73623,  -7.32906,
  -6.92189,  -6.51472,  -6.10755,  -5.70038,  -5.29321,  -4.88604,  -4.47887,
  -4.0717,   -3.66453,  -3.25736,  -2.85019,  -2.44302,  -2.03585,  -1.62868,
  -1.22151,  -0.81434,  -0.40717,   0.0,   0.40717,   0.81434,   1.22151,
   1.62868,   2.03585,   2.44302,   2.85019,   3.25736,   3.66453,   4.0717,
   4.47887,   4.88604,   5.29321,   5.70038,   6.10755,   6.51472,   6.92189,
   7.32906,   7.73623,   8.1434,    8.55057,   8.95774,   9.36491,   9.77208,
  10.17925,  10.58642,  10.99359,  11.40076,  11.80793,  12.2151,   12.62227,
  13.02944,  13.43661,  13.84378,  14.25095,  14.65812,  15.06529,  15.47246,
  15.87963,  16.2868,   16.69397,  17.10114,  17.50831,  17.91548,  18.32265,
  18.72982,  19.13699,  19.54416,  19.95133,  20.3585,   20.76567,  21.17284,
  21.58001,  21.98718,  22.39435,  22.80152,  23.20869,  23.61586,  24.02303,
  24.4302,   24.83737,  25.24454,  25.65171,  26.05888,  26.46605,  26.87322,
  27.28039,  27.68756,  28.09473,  28.5019,   28.90907,  29.31624,  29.72341,
  30.13058,  30.53775,  30.94492,  31.35209,  31.75926,  32.16643,  32.5736,
  32.98077,  33.38794,  33.79511,  34.20228,  34.60945,  35.01662,  35.42379,
  35.83096,  36.23813,  36.6453,   37.05247,  37.45964,  37.86681,  38.27398,
  38.68115,  39.08832,  39.49549,  39.90266,  40.30983,  40.717,    41.12417,
  41.53134,  41.93851,  42.34568,  42.75285,  43.16002,  43.56719,  43.97436,
  44.38153,  44.7887,   45.19587,  45.60304,  46.01021,  46.41738]

counts = [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 9.82318271e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 9.82318271e-04, 0.00000000e+00, 9.82318271e-04, 0.00000000e+00,
 0.00000000e+00, 9.82318271e-04, 0.00000000e+00, 9.82318271e-04,
 9.82318271e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 9.82318271e-04, 0.00000000e+00, 0.00000000e+00, 9.82318271e-04,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 9.82318271e-04, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.82318271e-04,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.82318271e-04,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 9.82318271e-04, 0.00000000e+00, 0.00000000e+00,
 1.96463654e-03, 9.82318271e-04, 9.82318271e-04, 0.00000000e+00,
 9.82318271e-04, 1.96463654e-03, 9.82318271e-04, 9.82318271e-04,
 7.85854617e-03, 9.82318271e-03, 1.27701375e-02, 1.47347741e-02,
 1.76817289e-02, 2.75049116e-02, 3.14341847e-02, 4.32220039e-02,
 5.79567780e-02, 6.77799607e-02, 9.43025540e-02, 1.29666012e-01,
 1.48330059e-01, 1.87622790e-01, 2.07269155e-01, 2.54420432e-01,
 3.00589391e-01, 3.33005894e-01, 4.03732809e-01, 4.72495088e-01,
 5.22593320e-01, 5.99214145e-01, 6.34577603e-01, 7.04322200e-01,
 8.18271120e-01, 8.58546169e-01, 9.26326130e-01, 9.65618861e-01,
 9.35166994e-01, 9.76424361e-01, 9.39096267e-01, 1.00000000e+00,
 9.67583497e-01, 9.36149312e-01, 9.13555992e-01, 9.38113949e-01,
 8.35952849e-01, 8.31041257e-01, 8.33988212e-01, 7.54420432e-01,
 7.17092338e-01, 6.12966601e-01, 6.22789784e-01, 5.37328094e-01,
 4.76424361e-01, 4.35166994e-01, 3.89980354e-01, 3.53634578e-01,
 3.47740668e-01, 3.51669941e-01, 2.87819253e-01, 2.67190570e-01,
 3.04518664e-01, 2.60314342e-01, 2.70137525e-01, 2.65225933e-01,
 2.65225933e-01, 2.67190570e-01, 3.06483301e-01, 2.72102161e-01,
 2.61296660e-01, 2.57367387e-01, 2.45579568e-01, 2.67190570e-01,
 2.25933202e-01, 2.28880157e-01, 2.21021611e-01, 2.23968566e-01,
 1.95481336e-01, 1.80746562e-01, 1.56188605e-01, 1.53241650e-01,
 1.23772102e-01, 1.47347741e-01, 1.26719057e-01, 8.93909627e-02,
 7.17092338e-02, 8.84086444e-02, 6.28683694e-02, 6.97445972e-02,
 6.58153242e-02, 4.61689587e-02, 4.51866405e-02, 4.61689587e-02,
 4.32220039e-02, 4.61689587e-02, 4.22396857e-02, 3.92927308e-02,
 3.43811395e-02, 3.14341847e-02, 2.45579568e-02, 3.53634578e-02,
 2.94695481e-02, 3.53634578e-02, 2.75049116e-02, 2.16110020e-02,
 3.14341847e-02, 3.63457760e-02, 1.96463654e-02, 2.94695481e-02,
 2.35756385e-02, 2.84872299e-02, 2.35756385e-02, 2.55402750e-02,
 2.06286837e-02, 1.86640472e-02, 3.33988212e-02, 1.96463654e-02,
 2.35756385e-02, 1.86640472e-02, 1.96463654e-02, 2.25933202e-02,
 2.45579568e-02, 2.84872299e-02, 1.96463654e-02, 1.96463654e-02,
 1.86640472e-02, 1.76817289e-02, 1.47347741e-02, 1.96463654e-02,
 2.65225933e-02, 2.06286837e-02, 2.45579568e-02, 2.06286837e-02,
 2.35756385e-02, 1.47347741e-02, 2.06286837e-02, 6.87622790e-03,
 1.27701375e-02, 1.86640472e-02, 1.66994106e-02, 2.35756385e-02,
 1.17878193e-02, 1.96463654e-02, 9.82318271e-03, 1.47347741e-02,
 1.08055010e-02, 1.17878193e-02, 1.27701375e-02, 1.27701375e-02,
 1.37524558e-02, 1.08055010e-02, 1.27701375e-02, 5.89390963e-03,
 1.08055010e-02, 9.82318271e-03]

def exp_mod_gauss(x, m, s, l):
    """Evaluate the exponentially modified Gaussian expression at x.

    Parameters:
        x: point(s) at which to evaluate; scalars and NumPy arrays both work
           because only NumPy/SciPy element-wise operations are used.
        m: Mu.
        s: Sigma.
        l: Lambda.

    Returns:
        0.5*l * exp(0.5*l*(2*m + l*s^2 - 2*x)) * erfc((m + l*s^2 - x)/(sqrt(2)*s))
    """
    half_lambda = 0.5*l
    exponential_part = np.exp(half_lambda*(2*m+l*s*s-2*x))
    erfc_part = erfc((m+l*s*s-x)/(np.sqrt(2)*s))
    return half_lambda*exponential_part*erfc_part

# Convert the raw Python lists to float arrays so the model can be evaluated element-wise.
bins = np.asarray(bins, dtype='float')
counts = np.asarray(counts, dtype='float')

# Least-squares fit of the three model parameters (Mu, Sigma, Lambda), starting from p0.
popt, pcov = curve_fit(exp_mod_gauss, bins, counts, p0=[-3.5, 2.8736, 0.1548])
fitted_func = exp_mod_gauss(bins, *popt)
#fitted_func = exp_mod_gauss(bins, -3.5, 2.8736, 0.1548) #used for manual example

# The model curve is divided by its own maximum before being drawn over the data.
plt.plot(bins, counts, 'o', markersize=1)  # measured counts
plt.plot(bins, fitted_func/max(fitted_func), '-')  # fitted curve, rescaled
plt.show()

如果按照编写的方式使用scipy拟合运行代码,则会得到以下结果: Result with curve_fit

显然不是很好。

但是,如果我注释掉使用curve_fit参数的fitted_func行,改用我在初始猜测中提供的参数(-3.5、2.8736、0.1548),则会得到以下结果: Result with manual parameters

因此,即使我提供给curve_fit的初始猜测基本上就是我想要的答案,拟合仍然失败了。在Matlab中执行完全相同的过程,我能得到很好的拟合参数,但是我不想使用Matlab。我想使用Python。

有人知道这里发生了什么吗?

非常感谢。

2 个答案:

答案 0 :(得分:2)

事实证明,拟合需要在EMG函数中增加一个自由度才能起作用,这样函数才能缩放到数据的幅度。如果我将EMG函数修改为:

def exp_mod_gauss(x, b, m, s, l):
    """Evaluate a scaled exponentially modified Gaussian expression at x.

    Parameters:
        x: point(s) at which to evaluate; scalars and NumPy arrays both work
           because only NumPy/SciPy element-wise operations are used.
        b: scaling factor multiplying the whole expression.
        m: Mu.
        s: Sigma.
        l: Lambda.

    Returns:
        b * 0.5*l * exp(0.5*l*(2*m + l*s^2 - 2*x)) * erfc((m + l*s^2 - x)/(sqrt(2)*s))
    """
    half_lambda = 0.5*l
    exponential_part = np.exp(half_lambda*(2*m+l*s*s-2*x))
    erfc_part = erfc((m+l*s*s-x)/(np.sqrt(2)*s))
    unscaled = half_lambda*exponential_part*erfc_part
    return b*unscaled

也就是说,加入用于缩放峰高的b项之后,问题就解决了。我提供初始猜测[1,-1,1,0],现在对我所期望的各种数据都能正常拟合。

答案 1 :(得分:0)

根据散点图,数据中似乎有两个相互重叠的独立峰。下面是一个带图形输出的Python拟合程序,它使用您的数据并将其拟合为两个高斯峰之和,并由scipy的Differential Evolution(差分进化)遗传算法模块为curve_fit()提供初始参数估计。该模块使用拉丁超立方(Latin Hypercube)算法来确保对参数空间进行彻底搜索,而这要求为每个参数给出搜索边界。在此示例中,这些边界取自数据的最大值和最小值,其中对于我认为应为正数的参数,使用0.0作为下限。

plot

import numpy, scipy, matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.optimize import differential_evolution
import warnings

bins = [-46.82455, -46.41738, -46.01021, -45.60304, -45.19587, -44.7887,  -44.38153,
 -43.97436, -43.56719, -43.16002, -42.75285, -42.34568, -41.93851, -41.53134,
 -41.12417, -40.717,   -40.30983, -39.90266, -39.49549, -39.08832, -38.68115,
 -38.27398, -37.86681, -37.45964, -37.05247, -36.6453,  -36.23813, -35.83096,
 -35.42379, -35.01662, -34.60945, -34.20228, -33.79511, -33.38794, -32.98077,
 -32.5736,  -32.16643, -31.75926, -31.35209, -30.94492, -30.53775, -30.13058,
 -29.72341, -29.31624, -28.90907, -28.5019,  -28.09473, -27.68756, -27.28039,
 -26.87322, -26.46605, -26.05888, -25.65171, -25.24454, -24.83737, -24.4302,
 -24.02303, -23.61586, -23.20869, -22.80152, -22.39435, -21.98718, -21.58001,
 -21.17284, -20.76567, -20.3585,  -19.95133, -19.54416, -19.13699, -18.72982,
 -18.32265, -17.91548, -17.50831, -17.10114, -16.69397, -16.2868,  -15.87963,
 -15.47246, -15.06529, -14.65812, -14.25095, -13.84378, -13.43661, -13.02944,
 -12.62227, -12.2151,  -11.80793, -11.40076, -10.99359, -10.58642, -10.17925,
  -9.77208,  -9.36491,  -8.95774,  -8.55057,  -8.1434,   -7.73623,  -7.32906,
  -6.92189,  -6.51472,  -6.10755,  -5.70038,  -5.29321,  -4.88604,  -4.47887,
  -4.0717,   -3.66453,  -3.25736,  -2.85019,  -2.44302,  -2.03585,  -1.62868,
  -1.22151,  -0.81434,  -0.40717,   0.0,   0.40717,   0.81434,   1.22151,
   1.62868,   2.03585,   2.44302,   2.85019,   3.25736,   3.66453,   4.0717,
   4.47887,   4.88604,   5.29321,   5.70038,   6.10755,   6.51472,   6.92189,
   7.32906,   7.73623,   8.1434,    8.55057,   8.95774,   9.36491,   9.77208,
  10.17925,  10.58642,  10.99359,  11.40076,  11.80793,  12.2151,   12.62227,
  13.02944,  13.43661,  13.84378,  14.25095,  14.65812,  15.06529,  15.47246,
  15.87963,  16.2868,   16.69397,  17.10114,  17.50831,  17.91548,  18.32265,
  18.72982,  19.13699,  19.54416,  19.95133,  20.3585,   20.76567,  21.17284,
  21.58001,  21.98718,  22.39435,  22.80152,  23.20869,  23.61586,  24.02303,
  24.4302,   24.83737,  25.24454,  25.65171,  26.05888,  26.46605,  26.87322,
  27.28039,  27.68756,  28.09473,  28.5019,   28.90907,  29.31624,  29.72341,
  30.13058,  30.53775,  30.94492,  31.35209,  31.75926,  32.16643,  32.5736,
  32.98077,  33.38794,  33.79511,  34.20228,  34.60945,  35.01662,  35.42379,
  35.83096,  36.23813,  36.6453,   37.05247,  37.45964,  37.86681,  38.27398,
  38.68115,  39.08832,  39.49549,  39.90266,  40.30983,  40.717,    41.12417,
  41.53134,  41.93851,  42.34568,  42.75285,  43.16002,  43.56719,  43.97436,
  44.38153,  44.7887,   45.19587,  45.60304,  46.01021,  46.41738]

counts = [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 9.82318271e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 9.82318271e-04, 0.00000000e+00, 9.82318271e-04, 0.00000000e+00,
 0.00000000e+00, 9.82318271e-04, 0.00000000e+00, 9.82318271e-04,
 9.82318271e-04, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 9.82318271e-04, 0.00000000e+00, 0.00000000e+00, 9.82318271e-04,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 9.82318271e-04, 0.00000000e+00,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.82318271e-04,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.82318271e-04,
 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
 0.00000000e+00, 9.82318271e-04, 0.00000000e+00, 0.00000000e+00,
 1.96463654e-03, 9.82318271e-04, 9.82318271e-04, 0.00000000e+00,
 9.82318271e-04, 1.96463654e-03, 9.82318271e-04, 9.82318271e-04,
 7.85854617e-03, 9.82318271e-03, 1.27701375e-02, 1.47347741e-02,
 1.76817289e-02, 2.75049116e-02, 3.14341847e-02, 4.32220039e-02,
 5.79567780e-02, 6.77799607e-02, 9.43025540e-02, 1.29666012e-01,
 1.48330059e-01, 1.87622790e-01, 2.07269155e-01, 2.54420432e-01,
 3.00589391e-01, 3.33005894e-01, 4.03732809e-01, 4.72495088e-01,
 5.22593320e-01, 5.99214145e-01, 6.34577603e-01, 7.04322200e-01,
 8.18271120e-01, 8.58546169e-01, 9.26326130e-01, 9.65618861e-01,
 9.35166994e-01, 9.76424361e-01, 9.39096267e-01, 1.00000000e+00,
 9.67583497e-01, 9.36149312e-01, 9.13555992e-01, 9.38113949e-01,
 8.35952849e-01, 8.31041257e-01, 8.33988212e-01, 7.54420432e-01,
 7.17092338e-01, 6.12966601e-01, 6.22789784e-01, 5.37328094e-01,
 4.76424361e-01, 4.35166994e-01, 3.89980354e-01, 3.53634578e-01,
 3.47740668e-01, 3.51669941e-01, 2.87819253e-01, 2.67190570e-01,
 3.04518664e-01, 2.60314342e-01, 2.70137525e-01, 2.65225933e-01,
 2.65225933e-01, 2.67190570e-01, 3.06483301e-01, 2.72102161e-01,
 2.61296660e-01, 2.57367387e-01, 2.45579568e-01, 2.67190570e-01,
 2.25933202e-01, 2.28880157e-01, 2.21021611e-01, 2.23968566e-01,
 1.95481336e-01, 1.80746562e-01, 1.56188605e-01, 1.53241650e-01,
 1.23772102e-01, 1.47347741e-01, 1.26719057e-01, 8.93909627e-02,
 7.17092338e-02, 8.84086444e-02, 6.28683694e-02, 6.97445972e-02,
 6.58153242e-02, 4.61689587e-02, 4.51866405e-02, 4.61689587e-02,
 4.32220039e-02, 4.61689587e-02, 4.22396857e-02, 3.92927308e-02,
 3.43811395e-02, 3.14341847e-02, 2.45579568e-02, 3.53634578e-02,
 2.94695481e-02, 3.53634578e-02, 2.75049116e-02, 2.16110020e-02,
 3.14341847e-02, 3.63457760e-02, 1.96463654e-02, 2.94695481e-02,
 2.35756385e-02, 2.84872299e-02, 2.35756385e-02, 2.55402750e-02,
 2.06286837e-02, 1.86640472e-02, 3.33988212e-02, 1.96463654e-02,
 2.35756385e-02, 1.86640472e-02, 1.96463654e-02, 2.25933202e-02,
 2.45579568e-02, 2.84872299e-02, 1.96463654e-02, 1.96463654e-02,
 1.86640472e-02, 1.76817289e-02, 1.47347741e-02, 1.96463654e-02,
 2.65225933e-02, 2.06286837e-02, 2.45579568e-02, 2.06286837e-02,
 2.35756385e-02, 1.47347741e-02, 2.06286837e-02, 6.87622790e-03,
 1.27701375e-02, 1.86640472e-02, 1.66994106e-02, 2.35756385e-02,
 1.17878193e-02, 1.96463654e-02, 9.82318271e-03, 1.47347741e-02,
 1.08055010e-02, 1.17878193e-02, 1.27701375e-02, 1.27701375e-02,
 1.37524558e-02, 1.08055010e-02, 1.27701375e-02, 5.89390963e-03,
 1.08055010e-02, 9.82318271e-03]

# NumPy array copies of the raw lists; the fitting functions below read these as globals.
xData, yData = numpy.array(bins), numpy.array(counts)


def func(X, a, b, c, f, g, h):
    """Return the sum of two Gaussian peaks evaluated at X.

    The first peak is a * exp(-0.5 * ((X - b) / c)**2) and the second is
    f * exp(-0.5 * ((X - g) / h)**2); a, b, c and f, g, h are the fitted
    parameters of the two peaks.
    """
    first_peak = a * numpy.exp(-0.5 * ((X - b) / c) ** 2)
    second_peak = f * numpy.exp(-0.5 * ((X - g) / h) ** 2)
    return first_peak + second_peak


# function for genetic algorithm to minimize (sum of squared error)
def sumOfSquaredError(parameterTuple):
    """Return the sum of squared differences between yData and the model.

    parameterTuple is unpacked into func() as its fitted parameters; the
    model is evaluated on the module-level xData and compared to yData.

    NOTE: warnings.filterwarnings("ignore") installs an "ignore" filter in
    the process-wide warnings filter list, so warnings stay suppressed after
    this function returns, not only while the genetic algorithm runs.
    """
    warnings.filterwarnings("ignore")  # do not print warnings by genetic algorithm
    residuals = yData - func(xData, *parameterTuple)
    return numpy.sum(residuals ** 2.0)


def generate_Initial_Parameters():
    """Search for the six parameters of func() with differential evolution.

    The search bounds are built from the smallest and largest values of the
    module-level xData: the parameters commented as positive (a, c, f, h) are
    searched in [0.0, max x], and b and g in [min x, max x].

    Returns:
        The `x` attribute of the differential_evolution result, i.e. the
        parameter vector it found, in the order a, b, c, f, g, h.
    """
    lowX = min(xData)
    highX = max(xData)

    parameterBounds = [
        [0.0, highX],   # search bounds for a, positive
        [lowX, highX],  # search bounds for b
        [0.0, highX],   # search bounds for c, positive
        [0.0, highX],   # search bounds for f, positive
        [lowX, highX],  # search bounds for g
        [0.0, highX],   # search bounds for h, positive
    ]

    # a fixed seed makes the random search repeatable
    return differential_evolution(sumOfSquaredError, parameterBounds, seed=3).x

# Run the differential-evolution search and use its result directly as the fit.
# NOTE: curve_fit() is imported but not called in the code shown here; per the
# SciPy documentation, differential_evolution's default polish step refines the
# best candidate with a bounded L-BFGS-B minimization.
fittedParameters = generate_Initial_Parameters()
print('Fitted parameters:', fittedParameters)
print()

modelPredictions = func(xData, *fittedParameters)

# residuals and goodness-of-fit statistics
absError = modelPredictions - yData
SE = numpy.square(absError)  # squared errors
MSE = SE.mean()  # mean squared errors
RMSE = numpy.sqrt(MSE)  # Root Mean Squared Error, RMSE
Rsquared = 1.0 - (absError.var() / yData.var())

print()
print('RMSE:', RMSE)
print('R-squared:', Rsquared)

print()


##########################################################
# graphics output section
def ModelAndScatterPlot(graphWidth, graphHeight):
    """Show the raw data as markers with the fitted model drawn as a line.

    graphWidth and graphHeight are each divided by 100.0 to give the figure
    size in inches at dpi=100. The data and the fitted parameters are read
    from the module-level xData, yData and fittedParameters.
    """
    figure = plt.figure(figsize=(graphWidth/100.0, graphHeight/100.0), dpi=100)
    axes = figure.add_subplot(111)

    # raw data as a scatter plot
    axes.plot(xData, yData, 'D')

    # fitted equation evaluated on an evenly spaced grid over the x range
    xModel = numpy.linspace(min(xData), max(xData))
    axes.plot(xModel, func(xModel, *fittedParameters))

    axes.set_xlabel('X Data')  # X axis data label
    axes.set_ylabel('Y Data')  # Y axis data label

    plt.show()
    plt.close('all')  # clean up after using pyplot

# Requested figure width and height, passed to the plotting routine.
graphWidth, graphHeight = 800, 600
ModelAndScatterPlot(graphWidth, graphHeight)