scipy curve_fit为float32输入提供了错误的结果

时间:2012-11-02 11:08:40

标签: floating-point scipy curve-fitting

在下面的一个非常简单的例子中(拟合直线y=ax+b具有已知结果`a = b = 0),当x变量以float32-type的形式提供时,SciPy curve_fit函数会产生错误的结果:

import numpy as np
import scipy.optimize as opt

def func(x, a, b): return a*x + b

x = np.array([0,1])
y = np.array([0,0])
p0 = [0.1, 0.1]
p1 = opt.curve_fit(func,            x , y, p0=p0)
p2 = opt.curve_fit(func, np.float64(x), y, p0=p0)
p3 = opt.curve_fit(func, np.float32(x), y, p0=p0)

print '\n p1 = ',p1,'\n p2 = ',p2,'\n p3 = ',p3

产生输出:

p1 =  (array([ 0.,  0.]), inf) 
p2 =  (array([ 0.,  0.]), inf) 
p3 =  (array([ 0.1,  0.1]), inf)

为什么最后的结果(使用float32-input获得)与其他两个不同,而且 - 显然 - 错了?显然,curve_fit甚至没有执行单次迭代,因为拟合结果p3等于初始猜测p0。对我来说,这似乎是一个错误......

0 个答案:

没有答案