Overflow error in a simple linear regression implementation in Python

Posted: 2017-11-01 17:32:45

Tags: python-3.x machine-learning linear-regression

I get this error when running my code:

    ErrorValue += ((m*x + b) - y)**2
RuntimeWarning: overflow encountered in double_scalars
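For context on the message itself: genfromtxt returns float64 values, and NumPy's scalar arithmetic on float64 ("double") scalars does not raise when a result exceeds the float64 range (about 1.8e308). It returns inf and emits a RuntimeWarning instead. A minimal sketch reproducing that behaviour, assuming an arbitrary value of 1e200 chosen only to force the overflow (newer NumPy releases word the message differently, e.g. "overflow encountered in scalar power"):

```python
import warnings
import numpy as np

# Illustrative value only: its square exceeds the largest float64 (~1.8e308).
big = np.float64(1e200)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record the warning instead of printing it
    squared = big ** 2               # NumPy scalar power: overflows to inf, no exception

print(squared)  # inf
for w in caught:
    print(w.category.__name__, w.message)
```

In a training loop, intermediate values of this size typically mean the parameters are growing on each update instead of settling.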

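Can someone explain what is wrong with my code? It would also be great to get some advice on whether I am applying linear regression correctly, and on ways to improve the code.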

Thank you very much!

import csv
import matplotlib.pyplot as plt
from numpy import *

'''
This is a simple implementation of linear regression on the correlation between
hours studied by students and the marks they obtained.
'''
def run():

    points = genfromtxt("data.csv", delimiter=",")

    # x is hours studied, y is marks obtained.

    # We are applying the function: y = b + mx
    for i in range(len(points)):
        x = points[i][0]
        y = points[i][1]

    N = len(points)
    b = 0
    m = 0
    alpha = 0.001 # alpha is the learning rate
    ErrorThreshold  = 0.003
    NumberOfIterations = 1000 # We cancel the gradient descent after a number of iterations, if it still doesn't reach the threshold we want.

    sum_m = 0
    sum_b = 0

    for i in range(NumberOfIterations):
        while mean_squared_error(x,y,b,m,points) > ErrorThreshold:
            b , m  = gradient_descent(m,b,alpha,x,y,N,points)

def mean_squared_error(x,y,b,m,points):
    ErrorValue = 0
    for i in range(len(points)):
        ErrorValue += ((m*x + b) - y)**2
    return ErrorValue / len(points) 


def gradient_descent(m,b,alpha,N,x,y,points):

    #dealing with summation sign in gradient descent
    sum_m = 0
    sum_b = 0

    for i in range(len(points)):
        x = points[i][0]
        y = points[i][1]
        sum_m += m*x + b - y
        sum_b += m*x + b - y 
        #repeating just for clarification purposes.

    new_b = b - (2/N)*sum_b
    new_m = m - (((2*m)/N))*sum_m

    return new_b, new_m

if __name__ == '__main__':
    run()

1 answer:

Answer 0 (score: 2)

What do you see when you trace (print) the intermediate values? For example, add a few lines to your routine:

def mean_squared_error(x,y,b,m,points):
    print("ENTER", x, y, b, m, len(points))
    ErrorValue = 0
    for i in range(len(points)):
        ErrorValue += ((m*x + b) - y)**2
        print("TRACE", i, ErrorValue)
    return ErrorValue / len(points) 

Also, I'm not sure your computation is correct; I think you might want

        ErrorValue += ((m*x[i] + b) - y[i])**2

Currently, you are adding the entire vector rather than a single scalar, and you are doing it len(points) times.
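A minimal sketch of that per-point fix, keeping the question's signature so the change is easy to compare. It assumes x and y hold the full columns of points (e.g. points[:, 0] and points[:, 1]); the three-row sample array is invented for illustration and stands in for data.csv.

```python
import numpy as np

def mean_squared_error(x, y, b, m, points):
    # x and y are 1-D arrays with one entry per point; points only supplies the length.
    ErrorValue = 0
    for i in range(len(points)):
        # Index x and y so each term is the scalar residual of point i.
        ErrorValue += ((m*x[i] + b) - y[i])**2
    return ErrorValue / len(points)

# Hypothetical sample data: columns are (hours studied, marks obtained).
points = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
x = points[:, 0]
y = points[:, 1]

print(mean_squared_error(x, y, 0, 2, points))  # 0.0, since y = 2x fits exactly
print(mean_squared_error(x, y, 0, 1, points))  # residuals -1, -2, -3 -> (1 + 4 + 9) / 3
```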

Finally, wouldn't it be easier to use len(x) and not pass points at all? You don't use points for anything except its length.
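Following that suggestion, here is a sketch of the whole fit that passes only x and y. The function names fit and gradient_descent_step are hypothetical, and the defaults (alpha = 0.001, threshold 0.003, 1000 iterations) are copied from the question. It uses the standard partial derivatives of the mean squared error for y = b + m*x, namely (2/N)·Σ(m·x_i + b − y_i) for b and (2/N)·Σ x_i·(m·x_i + b − y_i) for m, each scaled by the learning rate. It also uses a single bounded loop, because a `while` nested inside a `for` can spin forever if the threshold is never reached.

```python
import numpy as np

def mean_squared_error(x, y, b, m):
    # Average squared residual of the line y = b + m*x; len(x) replaces len(points).
    return np.mean((m * x + b - y) ** 2)

def gradient_descent_step(x, y, b, m, alpha):
    # One update using the partial derivatives of the mean squared error.
    n = len(x)
    residual = m * x + b - y
    grad_b = (2.0 / n) * np.sum(residual)
    grad_m = (2.0 / n) * np.sum(x * residual)  # the m-gradient is weighted by x
    return b - alpha * grad_b, m - alpha * grad_m

def fit(x, y, alpha=0.001, error_threshold=0.003, max_iterations=1000):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    b, m = 0.0, 0.0
    for _ in range(max_iterations):
        # Stop early once the error is small enough; otherwise give up after max_iterations.
        if mean_squared_error(x, y, b, m) <= error_threshold:
            break
        b, m = gradient_descent_step(x, y, b, m, alpha)
    return b, m
```

Assuming data.csv has hours in column 0 and marks in column 1, as the question's comment says, usage would be `points = np.genfromtxt("data.csv", delimiter=",")` followed by `b, m = fit(points[:, 0], points[:, 1])`. Whether the question's alpha and iteration count reach the 0.003 threshold depends on the scale of data.csv, which isn't shown. If the error grows instead of shrinking, a smaller alpha is the usual first thing to try.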