不久前我发现了scipy.misc.derivative
。
def derivative(func, x0, dx=1.0, n=1, args=(), order=3)
它的速度和精确度令人难以置信。
所以,我决定弄清楚它是如何工作的。
但我不完全理解这个函数的代码,它在derivative
:
def central_diff_weights(Np, ndiv=1)
据我了解,该函数在高阶导数的扩展中找到系数,但它是如何做到的?
来源: https://github.com/scipy/scipy/blob/v0.18.0/scipy/misc/common.py#L179-L249
代码:
def central_diff_weights(Np, ndiv=1):
"""
Return weights for an Np-point central derivative.
Assumes equally-spaced function points.
If weights are in the vector w, then
derivative is w[0] * f(x-ho*dx) + ... + w[-1] * f(x+h0*dx)
Parameters
----------
Np : int
Number of points for the central derivative.
ndiv : int, optional
Number of divisions. Default is 1.
Notes
-----
Can be inaccurate for large number of points.
"""
if Np < ndiv + 1:
raise ValueError("Number of points must be at least the derivative order
+ 1.")
if Np % 2 == 0:
raise ValueError("The number of points must be odd.")
from scipy import linalg
ho = Np >> 1
x = arange(-ho,ho+1.0)
x = x[:,newaxis]
X = x**0.0
for k in range(1,Np):
X = hstack([X,x**k])
w = product(arange(1,ndiv+1),axis=0)*linalg.inv(X)[ndiv]
return w