分析求解三次样条插值数据导数的零点

时间:2014-08-22 21:44:18

标签: python scipy interpolation

我有一组数据,我使用3阶(立方)的单变量样条插值三次样条。我想做一种形式的峰值检测,而不是采用插值的导数和零点搜索,我只需取导数并将其插入二次方程中即可找到所有零。

这个功能到底是什么回归?因为为了生成一组位于插值上的数据,你必须为返回的项目提供一个点列表,如下所示

from numpy import linspace,exp
from numpy.random import randn
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline
x = linspace(-3, 3, 100) # original data x axis
y = exp(-x**2) + randn(100)/10  #original data y axis
s = UnivariateSpline(x, y, s=1) # interpolation, returned to value s
xs = linspace(-3, 3, 1000) #values for x axis
ys = s(xs) # create new y axis
plt.plot(x, y, '.-')
plt.plot(xs, ys)
plt.show()

那么这个函数究竟是什么叫做函数返回的?它列出了立方体的系数吗?如果是这样,我将如何通过区分这些值来寻找峰值?

1 个答案:

答案 0 :(得分:0)

对象sscipy.interpolate.fitpack2.UnivariateSpline的实例。

作为任何python对象,它都有一些可用于操作的属性和方法 他们。要了解使用某个python对象可以执行的操作,可以始终使用内置函数 函数typedirvars

在这种情况下,dir会有所帮助。跑吧

dir(s)

您将获得s的所有属性,这些属性为:

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__doc__',
 '__format__',
 '__getattribute__',
 '__hash__',
 '__init__',
 '__module__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_data',
 '_eval_args',
 '_from_tck',
 '_reset_class',
 '_reset_nest',
 '_set_class',
 'antiderivative',
 'derivative',
 'derivatives',
 'get_coeffs',
 'get_knots',
 'get_residual',
 'integral',
 'roots',
 'set_smoothing_factor']

Python使用名称为startin with和underscore的属性和方法的约定 是私人的,所以除非你知道自己在做什么,否则不要使用它们。但正如你所看到的那样 列表的末尾包含您想要的信息s:它包含样条系数,导数,根等。

让我们来解决这个例子:

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import UnivariateSpline
x = np.linspace(-3, 3, 100) # original data x axis
y = np.exp(-x**2) + randn(100)/10  #original data y axis
s = UnivariateSpline(x, y, s=1) # interpolation, returned to value s

# watch the changes
xs = np.linspace(-3, 3, 10000) #values for x axis
ys = s(xs) # create new y axis
ds = s.derivative(n=1)  # get the derivative
dy = ds(xs)  # compute it on xs
tol=1e-4  # stabilish a tolerance 
root_index = np.where((dy>-tol)&(dy<tol))  # find indices where dy is  close to zero within tolerance
root = xs[root_index]  # get the correspondent xs values
root = set(np.round(root, decimals=2).tolist())  # remove redundancy duo to tolerance
root = np.array(list(root))
print(root)
plt.plot(x, y, '.-')
plt.plot(xs, ys)
plt.plot(xs, dy, 'r--')  # plot the derivative
plt.vlines(root, -1, 1, lw=2, alpha=.4)  # draw vertical lines through each root
plt.hlines(0, -3, 3, lw=2, alpha=.4)  # draw a horizontal line through zero