scipy.optimize.shgo ValueError:包含多个元素的数组的真值不明确。使用a.any()或a.all()

时间:2019-03-08 15:30:59

标签: python scipy curve-fitting

我正在尝试拟合函数y(x,T,p)以获取系数abcde,{ {1}}。 fyxT的数据是已知的。使用全局优化程序,我想找到一个很好的起点。 p似乎是唯一接受shgo的人。

constraints

有了这个,我得到import numpy as np import matplotlib.pyplot as plt from scipy.optimize import shgo # test data x = np.array([0.1,0.2,0.3,1]) T = np.array([300,300,300,300]) p = np.array([67.2,67.2,67.2,67.2]) y = np.array([30,50,55,67.2]) # function def func(pars,x,T,p): a,b,c,d,e,f = pars return x*p+x*(1-x)*(a+b*T+c*T**2+d*x+e*x*T+f*x*T**2)*p # residual def resid(pars): return ((func(pars,x,T,p) - y) ** 2).sum() # constraint: derivation is positive in every data point def der(pars): a,b,c,d,e,f = pars return -p*((3*f*T**2+3*e*T+3*d)*x**2+((2*c-2*f)*T**2+(2*b-2*e)*T-2*d+2*a)*x-c*T**2-b*T-a-1) con1 = ({'type':'ineq', 'fun':der}) # minimizer shgo bounds = [(-1,1),(-1,1),(-1,1),(-1,1),(-1,1),(-1,1)] res = shgo(resid, bounds, constraints=con1) print("a = %f , b = %f, c = %f, d = %f, e = %f, f = %f" % (res[0], res[1], res[2], res[3], res[4], res[5])) # plotting x0 = np.linspace(0, 1, 100) fig, ax = plt.subplots() fig.dpi = 80 ax.plot(x,y,'ro',label='data') for i,txt in enumerate(T): ax.annotate(txt,(x[i],y[i])) ax.plot(x0, func(res.x, x0, 300,67.2), '-', label='fit1') plt.xlabel('x') plt.ylabel('y') plt.legend() plt.show() 我不知道该错误的含义,并且具有相同错误的其他线程并不能真正帮助我理解。当我使用本地最小化器(ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 和方法scipy.optimize.minimize)时,不会出现错误。

有人可以帮助我理解我的问题,甚至帮助解决它吗? 谢谢

编辑:

cobyla

1 个答案:

答案 0 :(得分:0)

问题是der返回一个数组,而不是标量值。更改

con1 = ({'type':'ineq', 'fun':der})

con_list = [{'type':'ineq', 'fun': lambda x: der(x)[i_out]} for i_out in range(T.shape[0])]

消除错误。这样会将der的每个输出转换成自己的不等式约束。