我想直观地绘制针对给定斜率计算的误差函数的3D图和线性回归的y截距。 该图将用于说明梯度下降应用。
假设我们想用一条线模拟一组点。为此,我们将使用标准y = mx + b线方程,其中m是线的斜率,b是线的y轴截距。为了找到我们数据的最佳线,我们需要找到最佳的斜率m和y截距b值。
解决此类问题的标准方法是定义一个误差函数(也称为成本函数),用于衡量给定线的“良好”程度。此函数将接收(m,b)对并根据线与数据的拟合程度返回错误值。为了计算给定线的这个误差,我们将迭代数据集中的每个(x,y)点,并将每个点的y值和候选线的y值(在mx + b处计算)之间的平方距离求和。通常将这个距离平方以确保它是正的并使我们的误差函数可微分。在python中,计算给定行的错误将如下所示:
# y = mx + b
# m is slope, b is y-intercept
def computeErrorForLineGivenPoints(b, m, points):
totalError = 0
for i in range(0, len(points)):
totalError += (points[i].y - (m * points[i].x + b)) ** 2
return totalError / float(len(points))
由于误差函数由两个参数(m和b)组成,我们可以将它可视化为二维表面。
现在我的问题是,我们如何使用python绘制这样的3D图形?
这是构建3D绘图的骨架代码。此代码段完全不在问题上下文中,但它显示了构建3D绘图的基础知识。 对于我的例子,我需要x轴是斜率,y轴是y轴截距,z轴是误差。
有人可以帮我构建这样的图形示例吗?
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import random
def fun(x, y):
return x**2 + y
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x = y = np.arange(-3.0, 3.0, 0.05)
X, Y = np.meshgrid(x, y)
zs = np.array([fun(x,y) for x,y in zip(np.ravel(X), np.ravel(Y))])
Z = zs.reshape(X.shape)
ax.plot_surface(X, Y, Z)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
上面的代码产生了以下图表,这与我正在寻找的非常相似。
答案 0 :(得分:3)
只需将fun
替换为computeErrorForLineGivenPoints
:
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import collections
def error(m, b, points):
totalError = 0
for i in range(0, len(points)):
totalError += (points[i].y - (m * points[i].x + b)) ** 2
return totalError / float(len(points))
x = y = np.arange(-3.0, 3.0, 0.05)
Point = collections.namedtuple('Point', ['x', 'y'])
m, b = 3, 2
noise = np.random.random(x.size)
points = [Point(xp, m*xp+b+err) for xp,err in zip(x, noise)]
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ms = np.linspace(2.0, 4.0, 10)
bs = np.linspace(1.5, 2.5, 10)
M, B = np.meshgrid(ms, bs)
zs = np.array([error(mp, bp, points)
for mp, bp in zip(np.ravel(M), np.ravel(B))])
Z = zs.reshape(M.shape)
ax.plot_surface(M, B, Z, rstride=1, cstride=1, color='b', alpha=0.5)
ax.set_xlabel('m')
ax.set_ylabel('b')
ax.set_zlabel('error')
plt.show()
的产率
提示:我将computeErrorForLineGivenPoints
重命名为error
。通常,不需要命名函数compute...
,因为几乎所有函数都计算一些东西。您也不需要指定" GivenPoints"因为函数签名显示points
是一个参数。如果您的程序中有其他错误函数或变量,line_error
或total_error
可能是此函数的更好名称。