numpy的polyfit将无法正确地适应抛物线

时间:2019-05-02 18:36:51

标签: python numpy

bad fit image

我正在尝试使用polyfit将抛物线拟合到“数据”中的数据点集。我的程序正在为我尝试的其他数据集工作,但不适用于我的特定数据集。

我尝试确保x数据按升序正确排序,并且尝试将数据拟合到excel中。在excel中看起来合适。

data = [[0.16888549099999922, 7.127084528823611], [0.16888549199999922, 6.993992044491425], [0.16888549299999922, 6.866362061761786], [0.16888549399999922, 6.744197905413327], [0.16888549499999922, 6.627501951010496], [0.16888549599999922, 6.516275651493945], [0.16888549699999922, 6.41051952560987], [0.16888549799999922, 6.310233194927246], [0.16888549899999922, 6.215415356951307], [0.16888549999999922, 6.1260638293986895], [0.16888550099999922, 6.042175535068139], [0.16888550199999922, 5.963746518271748], [0.16888550299999922, 5.890771966813277], [0.16888550399999921, 5.823246197254835], [0.16888550499999921, 5.761162692791979], [0.1688855059999992, 5.704514090084273], [0.1688855069999992, 5.653292206179698], [0.1688855079999992, 5.6074880385927806], [0.1688855089999992, 5.567091780688852], [0.1688855099999992, 5.532092833944616], [0.1688855109999992, 5.502479813247546], [0.1688855119999992, 5.4782405647173915], [0.1688855129999992, 5.459362171880384], [0.1688855139999992, 5.4458309690410776], [0.1688855149999992, 5.437632552360041], [0.1688855159999992, 5.434751791325787], [0.1688855169999992, 5.43717284006055], [0.1688855179999992, 5.444879149278103], [0.1688855189999992, 5.457853477721287], [0.1688855199999992, 5.47607790389057], [0.1688855209999992, 5.4995338388052435], [0.1688855219999992, 5.52820203573491], [0.1688855229999992, 5.562062602591132], [0.1688855239999992, 5.601095018180885], [0.1688855249999992, 5.645278137030775], [0.1688855259999992, 5.694590207075108], [0.1688855269999992, 5.7490088760994285], [0.1688855279999992, 5.808511211230821], [0.1688855289999992, 5.873073702228439], [0.1688855299999992, 5.942672282488598], [0.1688855309999992, 6.017282332411738], [0.1688855319999992, 6.096878700665825], [0.1688855329999992, 6.1814357028181135], [0.1688855339999992, 6.27092714427376], [0.1688855349999992, 6.365326333275836], [0.1688855359999992, 6.464606084709255], [0.1688855369999992, 6.56873873137477], [0.1688855379999992, 6.677696151240223], [0.1688855389999992, 6.791449748897025], [0.1688855399999992, 6.909970506323868]]

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

matplotlib.use('TkAgg')

x = list()
y = list()
for i in data:
    x.append(i[0])
    y.append(i[1])

fit = np.polyfit(x, y, 2)
f = np.poly1d(fit)

plt.scatter(x,y)
plt.plot(x,f(x))
plt.xlim(0.16888549099999922 - 0.000000001,0.1688855399999992 + 0.000000001)
plt.show()

我看起来像一条线。存在拟合条件差的错误。我无法弄清楚为什么这会在我的代码中发生,因此我设置了这个小代码以清楚地说明问题。有人可以帮我吗?

1 个答案:

答案 0 :(得分:6)

x值之间的差异很小:

In [40]: np.diff(x).max()
Out[54]: 9.999999994736442e-10

当输入不是非常小的时候,一些数字配方会更好地工作。 (对于 例如,以固定步长(例如0.1)开始的算法可能很好 适用于大多数单位大小的数据,但完全超出了您的最佳系数 情况。)

如果您将数据标准化:

x = (x - x.mean())/x.std()

那么您将得到一个更明智的结果:


data = [[0.16888549099999922, 7.127084528823611], [0.16888549199999922, 6.993992044491425], [0.16888549299999922, 6.866362061761786], [0.16888549399999922, 6.744197905413327], [0.16888549499999922, 6.627501951010496], [0.16888549599999922, 6.516275651493945], [0.16888549699999922, 6.41051952560987], [0.16888549799999922, 6.310233194927246], [0.16888549899999922, 6.215415356951307], [0.16888549999999922, 6.1260638293986895], [0.16888550099999922, 6.042175535068139], [0.16888550199999922, 5.963746518271748], [0.16888550299999922, 5.890771966813277], [0.16888550399999921, 5.823246197254835], [0.16888550499999921, 5.761162692791979], [0.1688855059999992, 5.704514090084273], [0.1688855069999992, 5.653292206179698], [0.1688855079999992, 5.6074880385927806], [0.1688855089999992, 5.567091780688852], [0.1688855099999992, 5.532092833944616], [0.1688855109999992, 5.502479813247546], [0.1688855119999992, 5.4782405647173915], [0.1688855129999992, 5.459362171880384], [0.1688855139999992, 5.4458309690410776], [0.1688855149999992, 5.437632552360041], [0.1688855159999992, 5.434751791325787], [0.1688855169999992, 5.43717284006055], [0.1688855179999992, 5.444879149278103], [0.1688855189999992, 5.457853477721287], [0.1688855199999992, 5.47607790389057], [0.1688855209999992, 5.4995338388052435], [0.1688855219999992, 5.52820203573491], [0.1688855229999992, 5.562062602591132], [0.1688855239999992, 5.601095018180885], [0.1688855249999992, 5.645278137030775], [0.1688855259999992, 5.694590207075108], [0.1688855269999992, 5.7490088760994285], [0.1688855279999992, 5.808511211230821], [0.1688855289999992, 5.873073702228439], [0.1688855299999992, 5.942672282488598], [0.1688855309999992, 6.017282332411738], [0.1688855319999992, 6.096878700665825], [0.1688855329999992, 6.1814357028181135], [0.1688855339999992, 6.27092714427376], [0.1688855349999992, 6.365326333275836], [0.1688855359999992, 6.464606084709255], [0.1688855369999992, 6.56873873137477], [0.1688855379999992, 6.677696151240223], [0.1688855389999992, 6.791449748897025], [0.1688855399999992, 6.909970506323868]]

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

matplotlib.use('TkAgg')

x, y = map(np.array, zip(*data))
x_normalized = (x - x.mean())/x.std()
fit_normalized = np.polyfit(x_normalized, y, 2)

f = np.poly1d(fit_normalized)
plt.scatter(x, y, marker='o', c='red')
plt.plot(x,f(x_normalized), c='black')
plt.xlim(0.16888549099999922 - 0.000000001,0.1688855399999992 + 0.000000001)
plt.show()

产量 enter image description here


以上,fit_normalized = np.polyfit(x_normalized, y, 2)计算有关归一化x数据的系数。要找到与原始数据有关的系数, 做侧面计算:

让我们

m, s = x.mean(), x.std()
x_normalized = (x - m)/s

您可以将其视为坐标变换。然后

y = a * x_normalized**2 + b * x_normalized + c
y = a * ((x - m)/s)**2 + b * ((x - m)/s) + c

现在,您可以展开和收集项以找到关于x的系数。或者,您可以使用sympy之类的符号数学包:

In [55]: import sympy as sym
In [57]: x, a, b, c, m, s = sym.symbols('x a b c m s')
In [104]: sym.poly(a * ((x - m)/s)**2 + b * ((x - m)/s) + c, x).coeffs()
Out[104]: [a/s**2, (-2*a*m + b*s)/s**2, (a*m**2 - b*m*s + c*s**2)/s**2]

表明

y = a/s**2 * x**2 + (-2*a*m + b*s)/s**2 * x + (a*m**2 - b*m*s + c*s**2)/s**2

只是为了证明上述计算得出了合理的拟合结果:

data = [[0.16888549099999922, 7.127084528823611], [0.16888549199999922, 6.993992044491425], [0.16888549299999922, 6.866362061761786], [0.16888549399999922, 6.744197905413327], [0.16888549499999922, 6.627501951010496], [0.16888549599999922, 6.516275651493945], [0.16888549699999922, 6.41051952560987], [0.16888549799999922, 6.310233194927246], [0.16888549899999922, 6.215415356951307], [0.16888549999999922, 6.1260638293986895], [0.16888550099999922, 6.042175535068139], [0.16888550199999922, 5.963746518271748], [0.16888550299999922, 5.890771966813277], [0.16888550399999921, 5.823246197254835], [0.16888550499999921, 5.761162692791979], [0.1688855059999992, 5.704514090084273], [0.1688855069999992, 5.653292206179698], [0.1688855079999992, 5.6074880385927806], [0.1688855089999992, 5.567091780688852], [0.1688855099999992, 5.532092833944616], [0.1688855109999992, 5.502479813247546], [0.1688855119999992, 5.4782405647173915], [0.1688855129999992, 5.459362171880384], [0.1688855139999992, 5.4458309690410776], [0.1688855149999992, 5.437632552360041], [0.1688855159999992, 5.434751791325787], [0.1688855169999992, 5.43717284006055], [0.1688855179999992, 5.444879149278103], [0.1688855189999992, 5.457853477721287], [0.1688855199999992, 5.47607790389057], [0.1688855209999992, 5.4995338388052435], [0.1688855219999992, 5.52820203573491], [0.1688855229999992, 5.562062602591132], [0.1688855239999992, 5.601095018180885], [0.1688855249999992, 5.645278137030775], [0.1688855259999992, 5.694590207075108], [0.1688855269999992, 5.7490088760994285], [0.1688855279999992, 5.808511211230821], [0.1688855289999992, 5.873073702228439], [0.1688855299999992, 5.942672282488598], [0.1688855309999992, 6.017282332411738], [0.1688855319999992, 6.096878700665825], [0.1688855329999992, 6.1814357028181135], [0.1688855339999992, 6.27092714427376], [0.1688855349999992, 6.365326333275836], [0.1688855359999992, 6.464606084709255], [0.1688855369999992, 6.56873873137477], [0.1688855379999992, 6.677696151240223], [0.1688855389999992, 6.791449748897025], [0.1688855399999992, 6.909970506323868]]

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

matplotlib.use('TkAgg')

x, y = map(np.array, zip(*data))
m, s = x.mean(), x.std()
x_normalized = (x - m)/s

fit_normalized = a, b, c = np.polyfit(x_normalized, y, 2)
fit = (a / s**2), (-2*a*m/s**2 + b/s), a*m**2/s**2 - b*m/s + c

print(fit_normalized)
# [ 0.54960506 -0.05561036  5.43651191]

print(fit)
# (2639159957611392.0, -891431783709882.0, 75274958487869.69)

f = np.poly1d(fit)
plt.scatter(x, y, marker='o', c='red')
plt.plot(x, f(x), c='black')
plt.xlim(0.16888549099999922 - 0.000000001,0.1688855399999992 + 0.000000001)
plt.show()

产生大致相同的情节。请注意,s非常小,因此除以s会导致数字很大,而s中的小误差将导致系数的大误差。还请记住,NumPy浮点数组具有有限精度的dtype。 因此,根据您的实际数据集,估算的系数可能不是很准确。您可能需要自己使用一个任意精度的数学包(例如decimalgmpy2)来计算系数,以取得更好的效果,但是如果您的输入未知,则需要再次计算,对低精度数据进行精确计算将无济于事。