我刚刚在youtube上使用Siraj Raval的视频启动了机器学习,并尝试了使用来自kaggle.com的数据集使用Gradient Descent执行线性回归的视频“Intro - The Intelligence of Intelligence”的挑战。这是我的代码:
"""
An Example of a Linear Regression model.
Here i am taking an example from https://www.kaggle.com/alopez247/pokemon
to find a relation between variable "Total" and "HP".
"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import sys
import os
data = pd.read_csv("./pokemon_alopez247.csv")
d = {"Total": data['Total'],
"HP": data['HP']}
smallData = pd.DataFrame(d)
test = smallData.values
epsilon = 0.001
def compute_error_for_line(b, m, points):
"""Return the Error for Line given the points."""
totalError = 0
for i in range(0, len(points)):
x = test[i, 0]
y = test[i, 1]
totalError += (y - (m * x + b)) ** 2
return totalError / float(len(points))
def step_gradient(b_current, m_current, points, learningRate):
"""Return the new b and m points."""
b_gradient = 0
m_gradient = 0
N = float(len(points))
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
error = y - ((m_current * x) + b_current)
b_gradient += -(2 / N) * error
m_gradient += -(2 / N) * x * error
new_b = b_current - (learningRate * b_gradient)
new_m = m_current - (learningRate * m_gradient)
return [new_b, new_m]
def main():
"""Return and plot function here."""
plt.figure(num=None, figsize=(20, 10), dpi=80,
facecolor='w', edgecolor='k')
plt.axis([0, 780, 0, 260])
plt.ylabel("Total")
plt.xlabel("HP")
plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)
m = 0.3
b = -30
x = np.arange(800)
y = m * x + b
for i in range(30):
error = compute_error_for_line(b, m, test)
print("error :", error)
if(error > epsilon):
y = m * x + b
plt.plot(x, y)
b, m = step_gradient(b, m, test, 0.0001)
print("b , m :", b, ",", m)
plt.pause(0.01)
plt.show()
plt.pause(0.001)
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print('Interrupted')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
,输出为:
error : 193676.072288
b , m : -29.91451362 , 6.46934413315
/usr/local/lib/python3.5/dist-packages/matplotlib/backend_bases.py:2445: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
warnings.warn(str, mplDeprecation)
error : 16427.2683093
b , m : -29.9134163218 , 6.04491523016
error : 15588.2873385
b , m : -29.9065147511 , 6.07401898958
error : 15583.8939554
b , m : -29.9000125838 , 6.07192788394
error : 15583.4489928
b , m : -29.8934831191 , 6.07198242461
error : 15583.0227312
b , m : -29.8869557061 , 6.07188938575
error : 15582.5965792
b , m : -29.8804283262 , 6.07180649992
error : 15582.1704489
b , m : -29.8739011182 , 6.07172291798
error : 15581.74434
b , m : -29.8673740726 , 6.07163938615
error : 15581.3182523
b , m : -29.86084719 , 6.0715558531
error : 15580.8921858
b , m : -29.8543204704 , 6.07147232236
error : 15580.4661407
b , m : -29.8477939138 , 6.0713887937
error : 15580.0401168
b , m : -29.8412675201 , 6.07130526712
error : 15579.6141143
b , m : -29.8347412894 , 6.07122174263
error : 15579.1881329
b , m : -29.8282152217 , 6.07113822022
error : 15578.7621729
b , m : -29.821689317 , 6.0710546999
error : 15578.3362341
b , m : -29.8151635752 , 6.07097118166
error : 15577.9103166
b , m : -29.8086379963 , 6.07088766551
error : 15577.4844204
b , m : -29.8021125804 , 6.07080415145
error : 15577.0585455
b , m : -29.7955873275 , 6.07072063947
error : 15576.6326918
b , m : -29.7890622375 , 6.07063712957
error : 15576.2068594
b , m : -29.7825373104 , 6.07055362176
error : 15575.7810482
b , m : -29.7760125462 , 6.07047011604
error : 15575.3552583
b , m : -29.769487945 , 6.0703866124
error : 15574.9294897
b , m : -29.7629635067 , 6.07030311084
error : 15574.5037423
b , m : -29.7564392314 , 6.07021961138
error : 15574.0780162
b , m : -29.7499151189 , 6.07013611399
error : 15573.6523114
b , m : -29.7433911694 , 6.07005261869
error : 15573.2266278
b , m : -29.7368673827 , 6.06996912548
error : 15572.8009655
b , m : -29.730343759 , 6.06988563435
[Finished in 73.209s]
因此输出表明一切都按计划进行。但请看this。第一个蓝色是原始值,线越来越远!我尝试重写compute_error_for_line和step_gradient函数但仍然没有。 感谢您阅读到最后。
那么如何才能获得最适合我样本空间的行的参数呢?
链接到我的csv文件here(此文件将在22小时后过期)。
答案 0 :(得分:1)
plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)
看起来你交换了x和y值。如果将[1]更改为[0],反之亦然,则图表看起来非常好