FFT与傅立叶分量的最小二乘拟合?

时间:2014-05-09 10:05:29

标签: python numpy scipy fft curve-fitting

所以我得到了一个信号,我已经尝试使用两种方法拟合曲线,我认为这两种方法在数值上是等价的,但显然不是。

方法1:通过最小二乘法显式拟合正弦曲线:

def curve(x, a0, a1, b1, a2, b2):
  return a0 + a1*np.cos(x/720*2*math.pi) + b1*np.sin(x/720*2*math.pi) + a2*np.cos(x/720*2*math.pi*2) + b2*np.sin(x/720*2*math.pi*2)

def fit_curve(xdata, ydata):
  guess = [10, 0, 0, 0, 0]
  params, params_covariance = optimize.curve_fit(curve, xdata, ydata, guess)
  return params, params_covariance

方法2:使用内置FFT算法做同样的事情

  f = np.fft.rfft(y,3)
  curve = np.fft.irfft(f, width)

我有两个问题。第一个是次要的,FFT是超出规模的,所以我应用缩放因子mean(y)/mean(curve)来修复它,这有点像黑客攻击。我不确定为什么会这样。

我遇到的主要问题是我认为这些结果应该会产生几乎相同的结果,但它们并不存在。显式拟合每次都比FFT结果产生更紧密的拟合 - 我的问题是,应该吗?

2 个答案:

答案 0 :(得分:2)

看看docstring for np.fft.rfft。特别是:"如果n小于输入的长度,则输入被裁剪。"当你这样做时:

    f = np.fft.rfft(y,3)

您正计算y中前三个数据点的FFT,而不是y的前三个傅里叶系数。

答案 1 :(得分:2)

一个可以使用线性代数找到离散傅立叶变换系数,但我认为它只对更好地理解DFT有用。下面的代码演示了这一点。要找到正弦系列的系数和相位需要更多的工作,但不应该太难。代码评论中引用的维基百科文章可能有所帮助。

请注意,一个人不需要scipy.optimize.curve_fit,甚至不需要线性最小二乘法。事实上,虽然我在下面使用numpy.linalg.solve,但这是不必要的,因为basis是一个单位矩阵乘以比例因子。

from __future__ import division, print_function

import numpy

# points in time series
n= 101
# final time (initial time is 0)
tfin= 10

# *end of changeable parameters*

# stepsize
dt= tfin/(n-1)
# sample count
s= numpy.arange(n)
# signal; somewhat arbitrary
y= numpy.sinc(dt*s)
# DFT
fy= numpy.fft.fft(y)
# frequency spectrum in rad/sample
wps= numpy.linspace(0,2*numpy.pi,n+1)[:-1]

# basis for DFT
# see, e.g., http://en.wikipedia.org/wiki/Discrete_Fourier_transform#equation_Eq.2
# and section "Properties -> Orthogonality"; the columns of 'basis' are the u_k vectors
# described there
basis= 1.0/n*numpy.exp(1.0j * wps * s[:,numpy.newaxis])

# reconstruct signal from DFT coeffs and basis
recon_y= numpy.dot(basis,fy)

# expect yerr to be "small"
yerr= numpy.max(numpy.abs(y-recon_y))
print('yerr:',yerr)

# find coefficients by fitting to basis
lin_fy= numpy.linalg.solve(basis,y)

# fyerr should also be "small"
fyerr= numpy.max(numpy.abs(fy-lin_fy))
print('fyerr',fyerr)

在我的系统上,这给出了

yerr: 2.20721480995e-14
fyerr 1.76885950227e-13

使用Python 2.7和3.4在Ubuntu 14.04上进行测试。