A pure-Python implementation of gradient descent

Posted: 2020-07-26 00:17:39

Tags: python machine-learning gradient-descent

I tried to implement gradient descent myself in Python. I know there are similar questions, but in my attempt the guessed slope always gets very close to the true slope, while the guessed intercept never matches or even comes close to the true intercept. Does anyone know why that happens?

Also, I have read many articles and formulas about gradient descent, and they say that on each iteration I should multiply the gradient by the negative learning rate and repeat until convergence. As you can see in my implementation below, gradient descent only works when I multiply the learning rate by the gradient without the -1. Why is that? Do I misunderstand gradient descent, or is my implementation wrong? (If I multiply the learning rate and gradient by -1, exam_m and exam_b overflow very quickly.)

intercept = -5
slope = -4


x = []
y = []
for i in range(0, 100):
    x.append(i/300)
    y.append((i * slope + intercept)/300)

learning_rate = 0.005
# y = mx + b 
# m is slope, b is y-intercept

exam_m = 100
exam_b = 100

# iteration
# My error function is the sum of (y - guess)^2 over all samples
for _ in range(20000):
    gradient_m = 0
    gradient_b = 0
    for i in range(len(x)):
        gradient_m += (y[i] - exam_m * x[i] - exam_b) * x[i]
        gradient_b += (y[i] - exam_m * x[i] - exam_b)
        # why not gradient_m -= (y[i] - exam_m * x[i] - exam_b) * x[i], as in the gradient descent formula?

    exam_m += learning_rate * gradient_m
    exam_b += learning_rate * gradient_b
    print(exam_m, exam_b)
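For context, the two sign conventions are algebraically identical: the derivative of the sum of (y - guess)^2 with respect to m is -2 * sum((y - guess) * x), so subtracting the learning rate times that gradient is the same as adding the learning rate times the accumulated positive residual term (up to the factor of 2). A minimal sketch, using made-up sample data:

```python
# Sketch: both sign conventions produce the same parameter update.
xs = [0.1, 0.2, 0.3]
ys = [0.5, 0.7, 0.9]
m, b = 1.0, 0.0
lr = 0.01

# Convention A (as in the question): accumulate the positive residual, then add.
acc_m = sum((yi - m * xi - b) * xi for xi, yi in zip(xs, ys))
update_a = m + lr * acc_m

# Convention B (textbook): gradient of sum((y - y_hat)^2) w.r.t. m, then subtract.
grad_m = sum(-2 * (yi - m * xi - b) * xi for xi, yi in zip(xs, ys))
update_b = m - lr * (grad_m / 2)  # divide by 2 to match convention A's scaling

print(update_a, update_b)  # identical
```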

1 answer:

Answer 0 (score: 1)

The reason for the overflow is the missing factor (2/n). Below, I have written out the version with the negative sign in full for clarity.
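For reference, the minus sign and the (2/n) factor both come directly from differentiating the mean-squared error:

```latex
E(m, b) = \frac{1}{n} \sum_{i=1}^{n} \bigl(y_i - (m x_i + b)\bigr)^2

\frac{\partial E}{\partial m} = -\frac{2}{n} \sum_{i=1}^{n} \bigl(y_i - m x_i - b\bigr)\, x_i

\frac{\partial E}{\partial b} = -\frac{2}{n} \sum_{i=1}^{n} \bigl(y_i - m x_i - b\bigr)
```

Gradient descent then subtracts learning_rate times each derivative, which is exactly what the loop below does with the minus sign made explicit.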

import numpy as np
import matplotlib.pyplot as plt

intercept = -5
slope = -4
# y = mx + b

x = []
y = []

for i in range(0, 100):
    x.append(i/300)
    y.append((i * slope + intercept)/300)

n = len(x)
x = np.array(x)
y = np.array(y)

learning_rate = 0.05
exam_m = 0
exam_b = 0
epochs = 1000

for _ in range(epochs):
    gradient_m = 0
    gradient_b = 0
    for i in range(n):
        gradient_m -= (y[i] - exam_m * x[i] - exam_b) * x[i]
        gradient_b -= (y[i] - exam_m * x[i] - exam_b)

    exam_m = exam_m - (2/n)*learning_rate * gradient_m
    exam_b = exam_b - (2/n)*learning_rate * gradient_b

print('Slope, Intercept: ', exam_m, exam_b)

y_pred = exam_m*x + exam_b
plt.xlabel('x')
plt.ylabel('y')
plt.plot(x, y_pred, '--', color='black', label='predicted_line')
plt.plot(x, y, '--', color='blue', label='original_line')
plt.legend()
plt.show()
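The inner per-sample loop can also be vectorized with NumPy array operations. This is an equivalent sketch of the same update rule, not the answerer's exact code:

```python
import numpy as np

# Synthetic data matching the setup above: y = slope*x + intercept, scaled by 300.
intercept, slope = -5, -4
i = np.arange(100)
x = i / 300
y = (i * slope + intercept) / 300
n = len(x)

learning_rate = 0.05
exam_m = 0.0
exam_b = 0.0

for _ in range(1000):
    residual = y - exam_m * x - exam_b      # y - y_hat for every sample at once
    gradient_m = -np.sum(residual * x)      # same value as the loop's gradient_m
    gradient_b = -np.sum(residual)          # same value as the loop's gradient_b
    exam_m -= (2 / n) * learning_rate * gradient_m
    exam_b -= (2 / n) * learning_rate * gradient_b

print('Slope, Intercept:', exam_m, exam_b)
```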

Output: Slope, Intercept: -2.421033215481844 -0.2795651072061604

(Figure: the predicted line plotted against the original line)