需要提高fsolve的准确性以找到多个根

时间:2016-08-24 20:48:46

标签: python scipy mathematical-optimization

我使用此代码获取非线性函数的零点。 当然,该函数应该有1或3个零

import numpy as np
import matplotlib.pylab as plt
from scipy.optimize import fsolve

[a, b, c] = [5, 10, 0]

def func(x):
    return -(x+a) + b / (1 + np.exp(-(x + c)))

x = np.linspace(-10, 10, 1000)

print(fsolve(func, [-10, 0,  10]))
plt.plot(x, func(x))
plt.show()

在这种情况下,代码给出了3个预期的根,没有任何问题。 但是,如果c = -1.5,则代码会错过一个根,并且在c = -3时,它会找到一个不存在的根。

我想计算许多不同参数组合的根,因此手动更改种子不是一个实际的解决方案。

我感谢任何解决方案,技巧或建议。

2 个答案:

答案 0 :(得分:2)

您需要的是一种自动获取函数根的初始估计方法。然而,这通常是一项艰巨的任务,对于单变量,连续函数,它相当简单。我们的想法是注意到(a)这类函数可以通过适当大阶的多项式近似为任意精度,(b)有用于查找(全部)多项式根的有效算法。幸运的是,Numpy提供了执行多项式逼近和寻找多项式根的函数。

让我们考虑一个特定的功能

[a, b, c] = [5, 10, -1.5]

def func(x):
    return -(x+a) + b / (1 + np.exp(-(x + c)))

以下代码使用polyfitpoly1d通过订单{{func的多项式函数-10<x<10在感兴趣的范围(f_poly)上近似10 1}}。

x_range = np.linspace(-10,10,100)
y_range = func(x_range)

pfit = np.polyfit(x_range,y_range,10)

f_poly = np.poly1d(pfit)

如下图所示,f_poly确实是func的良好近似值。通过增加顺序可以获得更高的准确度。然而,在多项式近似中追求极端准确性是没有意义的,因为我们正在寻找根将被fsolve

稍后细化的根的近似估计。

enter image description here

多项式近似的根可以简单地获得为

roots = np.roots(pfit)
roots
  

数组([ - 10.4551 + 1.4893j,-10.4551-1.4893j,11.0027 + 0.j,            8.6679 + 2.482j,8.6679-2.482j,-5.7568 + 3.2928j,           -5.7568-3.2928j,-4.9269 + 0.j,4.7486 + 0.j,2.9158 + 0.j])

正如预期的那样,Numpy返回了10个复杂的根源。但是,我们只对区间[-10,10]内的真实根感兴趣。这些可以提取如下:

x0 = roots[np.where(np.logical_and(np.logical_and(roots.imag==0, roots.real>-10), roots.real<10))].real
x0
  

数组([ - 4.9269,4.7486,2.9158])

数组x0可以作为fsolve的初始化:

fsolve(func, x0)
  

数组([ - 4.9848,4.5462,2.7192])

备注:pychebfun包提供了一个函数,该函数直接给出一个区间内函数的所有根。它也基于执行多项式近似的思想,然而,它使用更复杂(但更有效)的方法。它会自动选择近似的最佳多项式阶数(无用户输入),多项式根实际上等于真实的(不需要通过fsolve对它们进行优化)。

这个简单的代码与fsolve

的代码具有相同的根
import pychebfun

f_cheb = pychebfun.Chebfun.from_function(func, domain = (-10,10))
f_cheb.roots()

答案 1 :(得分:2)

在两个静止点(即df/dx=0)之间,您有一个或零根。在您的情况下,可以分析地计算两个静止点:

[-c + log(1/(b - sqrt(b*(b - 4)) - 2)) + log(2),
 -c + log(1/(b + sqrt(b*(b - 4)) - 2)) + log(2)]

所以你有三个间隔,你需要找到零。使用Sympy可以避免手动进行计算。其sy.nsolve()允许在一个区间内稳健地找到零:

import sympy as sy

a, b, c, x = sy.symbols("a, b, c, x", real=True)

# The function:
f = -(x+a) + b / (1 + sy.exp(-(x + c)))
df = f.diff(x)  # calculate f' = df/dx
xxs = sy.solve(df, x)  # Solving for f' = 0 gives two solutions

# numerical values:
pp = {a: 5, b: 10, c: .5}  # values for a, b, c
fpp = f.subs(pp)
xxs_pp = [xpr.subs(pp).evalf() for xpr in xxs]  # numerical stationary points
xxs_pp.sort()  # in ascending order

# resulting intervals:
xx_low = [-1e9,      xxs_pp[0], xxs_pp[1]]
xx_hig = [xxs_pp[0], xxs_pp[1],       1e9]

# calculate roots for each interval:
xx0 = []
for xl_, xh_ in zip(xx_low, xx_hig):
    try:
        x0 = sy.nsolve(fpp, (xl_, xh_), solver="bisect")  # calculate zero
    except ValueError:  # no solution found
        continue
    xx0.append(x0)

print("The zeros are:")
print(xx0)

sy.plot(fpp)  # plot function