我正在尝试使用python和curve_fit函数解决示例营销组合模型问题。
我需要调整两组参数,并将它们作为* arg列表添加到函数中。我可以使曲线适合一组参数(一个列表),而不是两个。
#import packages
from scipy.optimize import curve_fit
import pandas as pd
import numpy as np
from statsmodels.tsa.filters.filtertools import recursive_filter as rec
a = np.array(0).repeat(150)
b = np.array(0).repeat(150)
c = np.array(0).repeat(150)
a[0:90] = np.random.uniform(5,10,(90,))
b[50:150] = np.random.uniform(20,40,(100,))
c[30:100] = np.random.uniform(5,25,(70,))
df = pd.DataFrame({'a':a,'b':b,'c':c})
def mmm(data,*param):
dic = {}
j = 0
for i in data:
dic[i] = rec(data[i],param[j])
j += 1
return(np.sum(pd.DataFrame(dic),1))
该函数将递归过滤器应用于具有不同lambda参数的data参数中的每个字段,并返回数据帧的行总和。
kpi = mmm(df,*(0.5,0.5,0.1)) + np.random.uniform(-5,5)
将*参数传递给scipy曲线拟合函数时,必须定义一个输出函数的函数。如此处所述:Pass tuple as input argument for scipy.optimize.curve_fit
a = np.zeros(3)
def make_func():
def mmm(data,*param):
dic = {}
j = 0
for i in data:
dic[i] = rec(data[i],param[j])
j += 1
return(np.sum(pd.DataFrame(dic),1))
return(mmm)
leastsq, covar = curve_fit(make_func(),df,kpi,a)
print(leastsq)
array([0.87560795, 0.87192766, 0.84864161])
def mmm(x,*arg):
c = args[0]
a = args[1]
dic = {}
j = 0
for i in x:
dic[i] = c[j] * rec(x[i], a[j])
j += 1
return(np.sum(pd.DataFrame(dic),1))
该函数将递归过滤器应用于具有不同lambda(a)的data参数中的每个字段,将其乘以标量(c)并获取数据帧的行总和。
args = [[4,5,3],[0.2,0.4,0.5]]
kpi = mmm(df,*args) + np.random.uniform(-5,5)
args = np.zeros(6)
def make_func():
def mmm(x,*args):
c = args[0]
a = args[1]
dic = {}
j = 0
for i in x:
dic[i] = c[j] * rec( x[i], a[j])
j += 1
return(np.sum(pd.DataFrame(dic),1))
return(make_func)
leastsq, covar = curve_fit(make_func,df, kpi, p0=args)
对一个参数列表使用相同的方法会导致两个错误。错误如下:
TypeError: make_func() takes 0 positional arguments but 7 were given
为了使此代码正常工作,还有其他事情要做吗?
干杯
答案 0 :(得分:1)
在我看来,有两点是错误的根源。
1)在函数make_func()
中拟合曲线的最后一部分中,您将返回函数本身。如果将其与以前的函数定义进行比较,我认为它应该是return(mmm)
。
2)args = np.zeros(6)
产生一个零数组,您将其作为参数传递给make_func()
。然后,您将c = args[0]
和a = args[1]
分配为基本上是标量变量的c=0
和a=0
。现在,在mmm(x,*args):
函数中,您将使用dic[i] = c[j] * rec( x[i], a[j])
。因为IndexError: invalid index to scalar variable.
和a
是标量,但是您正在对它们使用索引操作,所以这里弹出c
。