为什么SciPy的curve_fit函数关心xdata的类型?

时间:2018-12-17 13:10:05

标签: python numpy matplotlib scipy

我试图使用SciPy的curve_fit拟合一些数据,但结果却很奇怪。因此,我尝试,尝试和测试,并发现了xdata类型的问题。当xdata的类型为int时,结果变得很奇怪。但这并不适用于所有功能f。我用多项式进行了测试,直到阶数为6。从阶数3和以上开始,结果变得很奇怪。

最小示例:

import numpy as np
from scipy.optimize import curve_fit

def poly4(x, a, b, c, d, e):
    return a*np.power(x,4) + b*np.power(x,3) + c*np.power(x,2) + d*x + e

x = np.linspace(0, 9.6, 2400)
y = poly4(x, 0.03, -0.68, 5.6, -22, 1351)

x1 = np.arange(0, 2400, 1, dtype=np.dtype('float'))
x2 = np.arange(0, 2400, 1, dtype=np.dtype('int'))

popt1,_ = curve_fit(poly4, x1, y)
popt2,_ = curve_fit(poly4, x2, y)

f1 = poly4(x1, *popt1)
f2 = poly4(x2, *popt2)

用这些值绘制这些值

import matplotlib.pyplot as plt
plt.plot(f1, label='f1, float range')
plt.plot(f2, label='f2, int range')
plt.legend()
plt.show()

给予

curve_fit plot with int and float range

蓝线正是结果应为的样子。用{p>查看curve_fit的输出

print(popt1)
print(popt2)

给予

  

[9.05733149e-12 -4.92513534e-08 9.73032914e-05 -9.17048770e-02      1.35100000e + 03]

     

[3.52993170e-11 -1.52725549e-10 9.38577666e-06 -3.58806105e-02      1.34272489e + 03]

为什么这些结果如此不同?好吧,显然是由于xdata的数据类型。但是,curve_fit为什么要关心xdata的数据类型?我看不到其背后的原因,也没有找到有关它的任何文档。

编辑:在python 3.6.3scipy 0.19.1上对python 3.7.1scipy 1.1.0进行了测试。两者都在Windows上。

2 个答案:

答案 0 :(得分:1)

不是execv在乎curve_fit的类型,而是函数x。 Numpy在其操作中保留数组的类型。由于您使用的是整数的n次幂,因此您很快会遇到整数溢出,从而产生意外的结果。

例如参见np.power(x,3)的输出:

poly4

enter image description here

答案 1 :(得分:-1)

您和每个无法重现您问题的人都遇到的问题是,np.dtype('int')的大小在不同平台上是不同的。如果将x1x2的声明替换为:

x1 = np.arange(0, 2400, 1, dtype=np.dtype('float'))
x2 = np.arange(0, 2400, 1, dtype=np.int32)

然后,无论平台如何,您都可以始终如一地再现奇怪的输出:

enter image description here

最初的问题是由于np.int32太小而无法处理您正在计算的某些非常大的数字,并且中间计算的值溢出了。结果是:

poly4(np.arange(2000, 2010, dtype=np.int32), 0.03, -0.68, 5.6, -22, 1351)
# array([4.60917546e+08, 3.82703937e+08, 4.34772636e+08, 3.59427040e+08,
   4.14366625e+08, 3.41894792e+08, 3.99711018e+08, 3.30118704e+08,
   3.90817330e+08, 3.24110298e+08])

看起来与以下结果完全不同:

poly4(np.arange(2000, 2010, dtype=np.int64), 0.03, -0.68, 5.6, -22, 1351)
# array([4.74582357e+11, 4.75534936e+11, 4.76488948e+11, 4.77444394e+11,
   4.78401277e+11, 4.79359597e+11, 4.80319357e+11, 4.81280557e+11,
   4.82243198e+11, 4.83207283e+11])