我用牛顿法求解一维非线性方程。我试图弄清楚为什么牛顿方法的一个实现正好在浮点精度内收敛,而另一个则不是。
以下算法无法收敛:
而以下确实收敛:
您可以假设函数f和f'平滑且表现良好。我能想到的最好的解释是,这与所谓的迭代改进(Golub和Van Loan,1989)有某种关系。任何进一步的见解将不胜感激!
这是一个简单的python示例,说明了问题
# Python
def f(x):
return x*x-2.
def fp(x):
return 2.*x
xprev = 0.
# converges
x = 1. # guess
while x != xprev:
xprev = x
x = (x*fp(x)-f(x))/fp(x)
print(x)
# does not converge
x = 1. # guess
while x != xprev:
xprev = x
dx = -f(x)/fp(x)
x = x + dx
print(x)
注意:我知道浮点数是如何工作的(请不要将您喜欢的链接发布到一个网站,告诉我永远不要比较两个浮点数)。此外,我不是在寻找问题的解决方案,而是为了解释为什么其中一个算法收敛而不是另一个算法。
更新
正如@uhoh所指出的,在许多情况下,第二种方法不会收敛。但是,我仍然不知道为什么第二种方法在我的真实场景中比第一种方法更容易收敛。所有测试用例都有非常简单的函数f
,而真实世界f
有几百行代码(这就是为什么我不想发布它)。也许f
的复杂性很重要。如果您对此有任何其他见解,请告诉我们!
答案 0 :(得分:4)
这些方法都不完美:
两种方法都倾向于失败的一种情况是,根是在两个连续浮点数f1和f2之间的正好中间。然后,两个方法到达f1后,将尝试计算该中间值,并有很大的机会调高f2,反之亦然。
/f(x) / / / / f1 / --+----------------------+------> x / f2 / / /
答案 1 :(得分:3)
“我知道浮点数是如何工作的......”。也许浮点运算的工作比想象的要复杂得多。
这是使用牛顿方法循环迭代的典型示例。差异与epsilon的比较是“数学思维”,并且在使用浮点时可能会烧掉你。在您的示例中,您访问x的几个浮点值,然后您将陷入两个数字之间的循环中。 “浮点思维”更好地表达如下(抱歉,我的首选语言是C ++)
std::set<double> visited;
xprev = 0.0;
x = 1.0;
while (x != prev)
{
xprev = x;
dx = -F(x)/DF(x);
x = x + dx;
if (visited.find(x) != visited.end())
{
break; // found a cycle
}
visited.insert(x);
}
答案 2 :(得分:3)
我试图弄清楚为什么牛顿方法的一个实现正好在浮点精度内收敛,而另一个则没有。
从技术上讲,它并没有收敛到正确的值。尝试打印更多数字,或使用float.hex
。
第一个给出
>>> print "%.16f" % x
1.4142135623730949
>>> float.hex(x)
'0x1.6a09e667f3bccp+0'
而正确舍入的值是下一个浮点值:
>>> print "%.16f" % math.sqrt(2)
1.4142135623730951
>>> float.hex(math.sqrt(2))
'0x1.6a09e667f3bcdp+0'
第二种算法实际上是在两个值之间交替进行,因此不会收敛。
问题是由于f(x)
中的灾难性取消:因为x*x
将非常接近2,当你减去2时,结果将由舍入控制计算x*x
时出错。
答案 3 :(得分:2)
我认为试图强制完全相等(而不是错误&lt; small)总是会经常失败。在您的示例中,对于介于1和10之间的100,000个随机数(而不是2.0),第一种方法在大约1/3的时间内失败,第二种方法大约1/6的时间失败。我打赌可以预测到这一点!
这需要约30秒才能运行,结果很可爱!:
def f(x, a):
return x*x - a
def fp(x):
return 2.*x
def A(a):
xprev = 0.
x = 1.
n = 0
while x != xprev:
xprev = x
x = (x * fp(x) - f(x,a)) / fp(x)
n += 1
if n >100:
return n, x
return n, x
def B(a):
xprev = 0.
x = 1.
n = 0
while x != xprev:
xprev = x
dx = - f(x,a) / fp(x)
x = x + dx
n += 1
if n >100:
return n, x
return n, x
import numpy as np
import matplotlib.pyplot as plt
n = 100000
aa = 1. + 9. * np.random.random(n)
data_A = np.zeros((2, n))
data_B = np.zeros((2, n))
for i, a in enumerate(aa):
data_A[:,i] = A(a)
data_B[:,i] = B(a)
bins = np.linspace(0, 110, 12)
hist_A = np.histogram(data_A, bins=bins)
hist_B = np.histogram(data_B, bins=bins)
print "A: n<10: ", hist_A[0][0], " n>=100: ", hist_A[0][-1]
print "B: n<10: ", hist_B[0][0], " n>=100: ", hist_B[0][-1]
plt.figure()
plt.subplot(1,2,1)
plt.scatter(aa, data_A[0])
plt.subplot(1,2,2)
plt.scatter(aa, data_B[0])
plt.show()