使用类的__call__方法作为Numpy curve_fit的输入

时间:2013-01-02 15:22:11

标签: python class numpy call curve-fitting

我想使用类的__call__方法作为Numpy curve_fit函数的输入,因为我的功能和数据准备过程相当复杂(将分析模型数据拟合到某些测量中)。通过定义一个函数它可以正常工作,但我不能让它与类一起工作。

要重新创建我的问题,您可以运行:

import numpy as np
from scipy.optimize import curve_fit

#WORKS:
#def goal(x,a1,a2,a3,a4,a5):
#    y=a1*x**4*np.sin(x)+a2*x**3+a3*x**2+a4*x+a5
#    return y

# DOES NOT WORK:
class func():
    def __call__(self,x,a1,a2,a3,a4,a5):
        y=a1*x**4*np.sin(x)+a2*x**3+a3*x**2+a4*x+a5
        return y    

goal=func()

#data prepraration ***********
xdata=np.linspace(0,50,100)
ydata=goal(xdata,-2.1,-3.5,6.6,-1,2)
# ****************************

popt, pcov = curve_fit(goal, xdata, ydata)
print 'optimial parameters',popt
print 'The estimated covariance of optimial parameters',pcov

我得到的错误是:

Traceback (most recent call last):
  File "D:\...some path...\test_minimizacija.py", line 35, in <module>
    popt, pcov = curve_fit(goal, xdata, ydata)
  File "C:\Python26\lib\site-packages\scipy\optimize\minpack.py", line 412, in curve_fit
    args, varargs, varkw, defaults = inspect.getargspec(f)
  File "C:\Python26\lib\inspect.py", line 803, in getargspec
    raise TypeError('arg is not a Python function')
TypeError: arg is not a Python function

我该如何使这项工作?

1 个答案:

答案 0 :(得分:3)

简单(虽然不漂亮),只需将其更改为:

popt, pcov = curve_fit(goal.__call__, xdata, ydata)

有趣的是,numpy会强制您将函数对象传递给curve_fit而不是任意可调用...

快速检查curve_fit的来源,似乎另一种解决方法可能是:

popt,pcov = curve_fit(goal, xdata, ydata, p0=[1]*5)

此处,p0是拟合参数的初始猜测。问题似乎是scipy检查函数的参数,以便它知道如果您实际上没有提供参数作为初始猜测,将使用多少参数。在这里,由于我们有5个参数,我的初始猜测是所有长度为5的列表。(scipy默认使用1,如果你也不提供猜测。)