具有梯度下降的线性回归

时间:2016-12-23 08:44:38

标签: java machine-learning

我编写了以下Java程序来实现具有Gradient Descent的线性回归。代码执行但结果不准确。 y的预测值不接近y的实际值。例如,当x = 75时,预期的y = 208,但输出为y = 193.784。

class LinReg {

    double theta0, theta1;

    void buildModel(double[] x, double[] y) {
        double x_avg, y_avg, x_sum = 0.0, y_sum = 0.0;
        double xy_sum = 0.0, xx_sum = 0.0;
        int n = x.length, i;
        for( i = 0; i < n; i++ ) {
            x_sum += x[i];
            y_sum += y[i];
        }
        x_avg = x_sum/n;
        y_avg = y_sum/n;

        for( i = 0; i < n; i++) {
            xx_sum += (x[i] - x_avg) * (x[i] - x_avg);
            xy_sum += (x[i] - x_avg) * (y[i] - y_avg);
        }
        theta1 = xy_sum/xx_sum;
        theta0 = y_avg - (theta1 * x_avg);
        System.out.println(theta0);
        System.out.println(theta1);

        gradientDescent(x, y, 0.1, 1500);
    }

    void gradientDescent(double x[], double y[], double alpha, int maxIter) {
        double oldtheta0, oldtheta1;
        oldtheta0 = 0.0;
        oldtheta1 = 0.0;
        int n = x.length;
        for(int i = 0; i < maxIter; i++) {
            if(hasConverged(oldtheta0, theta0) && hasConverged(oldtheta1, theta1))
                break;
            oldtheta0 = theta0;
            oldtheta1 = theta1;
            theta0 = oldtheta0 - (alpha * (summ0(x, y, oldtheta0, oldtheta1)/(double)n));
            theta1 = oldtheta1 - (alpha * (summ1(x, y, oldtheta0, oldtheta1)/(double)n));
            System.out.println(theta0);
            System.out.println(theta1);


        }
    }

    double summ0(double x[], double y[], double theta0, double theta1) {
        double sum = 0.0;
        int n = x.length, i;
        for( i = 0; i < n; i++ ) {
            sum += (hypothesis(theta0, theta1, x[i]) - y[i]);
        }
        return sum;
    }

    double summ1(double x[], double y[], double theta0, double theta1) {
        double sum = 0.0;
        int n = x.length, i;
        for( i = 0; i < n; i++ ) {
            sum += (((hypothesis(theta0, theta1, x[i]) - y[i]))*x[i]);
        }
        return sum;
    }

    boolean hasConverged(double oldTheta, double newTheta) {
        return ((newTheta - oldTheta) < (double)0);
    }

    double predict(double x) {
        return hypothesis(theta0, theta1, x);
    }

    double hypothesis(double thta0, double thta1, double x) {
        return (thta0 + thta1 * x);
    }
}

public class LinearRegression {
    public static void main(String[] args) {
        //Height data
        double x[] = {63.0, 64.0, 66.0, 69.0, 69.0, 71.0, 71.0, 72.0, 73.0, 75.0};
        //Weight data
        double y[] = {127.0, 121.0, 142.0, 157.0, 162.0, 156.0, 169.0, 165.0, 181.0, 208.0};
        LinReg model = new LinReg();
        model.buildModel(x, y);
        System.out.println("----------------------");
        System.out.println(model.theta0);
        System.out.println(model.theta1);
        System.out.println(model.predict(75.0));
    }
}

1 个答案:

答案 0 :(得分:4)

没有错。

我在R中验证了解决方案:

x <- c(63.0, 64.0, 66.0, 69.0, 69.0, 71.0, 71.0, 72.0, 73.0, 75.0)
y <- c(127.0, 121.0, 142.0, 157.0, 162.0, 156.0, 169.0, 165.0, 181.0, 208.0)

mod <- lm(y~x)
summary(mod)
Call:
lm(formula = y ~ x)

Residuals:
     Min       1Q   Median       3Q      Max 
-13.2339  -4.0804  -0.0963   4.6445  14.2158 

Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept) -266.5344    51.0320  -5.223    8e-04 ***
x              6.1376     0.7353   8.347 3.21e-05 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 8.641 on 8 degrees of freedom
Multiple R-squared:  0.897,   Adjusted R-squared:  0.8841 
F-statistic: 69.67 on 1 and 8 DF,  p-value: 3.214e-05

计算y值为X值为75:

-266.5344 +(6.1376 *75)
  

[1] 193.784

这是一个正确的预测。我认为混淆必须围绕回归如何运作。回归不会告诉您训练数据中与给定独立数据点对应的数据点的精确实际值。那只是一本字典,而不是统计模型(在这种情况下,它无法进行插值或推断)。

回归拟合数据的最小二乘线以估计模型方程,然后使用该模型方程来预测给定自变量值的因变量值。唯一能够准确预测训练数据中数据点的情况是你的模型过度适应(这很糟糕)。

有关更多信息和链接:

https://en.wikipedia.org/wiki/Regression_analysis