我想用拉格朗日方法插值多项式,但这段代码不起作用:
def interpolate(x_values, y_values):
def _basis(j):
p = [(x - x_values[m])/(x_values[j] - x_values[m]) for m in xrange(k + 1) if m != j]
return reduce(operator.mul, p)
assert len(x_values) != 0 and (len(x_values) == len(y_values)), 'x and y cannot be empty and must have the same length'
k = len(x_values)
return sum(_basis(j) for j in xrange(k))
我跟着Wikipedia,但是当我运行它时,我在第3行收到一个IndexError!
由于
答案 0 :(得分:5)
尝试
def interpolate(x, x_values, y_values):
def _basis(j):
p = [(x - x_values[m])/(x_values[j] - x_values[m]) for m in xrange(k) if m != j]
return reduce(operator.mul, p)
assert len(x_values) != 0 and (len(x_values) == len(y_values)), 'x and y cannot be empty and must have the same length'
k = len(x_values)
return sum(_basis(j)*y_values[j] for j in xrange(k))
您可以按照以下方式确认:
>>> interpolate(1,[1,2,4],[1,0,2])
1.0
>>> interpolate(2,[1,2,4],[1,0,2])
0.0
>>> interpolate(4,[1,2,4],[1,0,2])
2.0
>>> interpolate(3,[1,2,4],[1,0,2])
0.33333333333333331
因此,结果是基于经过给定点的多项式的插值。在这种情况下,3个点定义一个抛物线,前3个测试表明为给定的x_value返回了所述的y_value。
答案 1 :(得分:3)
检查索引,维基百科说“k + 1个数据点”,但是如果你完全按照公式,那么你设置的k = len(x_values)
应该是k = len(x_values) - 1
。
答案 2 :(得分:1)
我聚会晚了将近十年,但是我发现这是在寻找一种简单实现拉格朗日插值的方法。 @smichr的答案很好,但是Python有点过时了,我还想要一些可以与np.ndarrays
一起很好地工作的东西,因此我可以轻松地进行绘制。也许其他人会发现这很有用:
import numpy as np
import matplotlib.pyplot as plt
class LagrangePoly:
def __init__(self, X, Y):
self.n = len(X)
self.X = np.array(X)
self.Y = np.array(Y)
def basis(self, x, j):
b = [(x - self.X[m]) / (self.X[j] - self.X[m])
for m in range(self.n) if m != j]
return np.prod(b, axis=0) * self.Y[j]
def interpolate(self, x):
b = [self.basis(x, j) for j in range(self.n)]
return np.sum(b, axis=0)
X = [-9, -4, -1, 7]
Y = [5, 2, -2, 9]
plt.scatter(X, Y, c='k')
lp = LagrangePoly(X, Y)
xx = np.arange(-100, 100) / 10
plt.plot(xx, lp.basis(xx, 0))
plt.plot(xx, lp.basis(xx, 1))
plt.plot(xx, lp.basis(xx, 2))
plt.plot(xx, lp.basis(xx, 3))
plt.plot(xx, lp.interpolate(xx), linestyle=':')
plt.show()
答案 3 :(得分:0)
此代码与Python 3
兼容:
def Lagrange (Lx, Ly):
x=sympy.symbols('x')
if len(Lx)!= len(Ly):
return 1
y=0
for k in range ( len(Lx) ):
t=1
for j in range ( len(Lx) ):
if j != k:
t=t* ( (x-Lx[j]) /(Lx[k]-Lx[j]) )
y+= t*Ly[k]
return y