线性回归的梯度下降不起作用

时间:2016-06-29 09:14:40

标签: java machine-learning statistics artificial-intelligence gradient-descent

我尝试使用梯度下降为一些样本数据制作线性回归程序。我得到的theta值不能最好地适合数据。我已经将数据标准化了。

public class OneVariableRegression {

    public static void main(String[] args) {

        double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};  
        double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double theta0 = 0.5;
        double theta1 = 0.5;
        double temp0;
        double temp1;
        double alpha = 1.5;
        double m = x1.length;
        System.out.println(m);
        double derivative0 = 0;
        double derivative1 = 0;
        do {
                    for (int i = 0; i < x1.length; i++) {
            derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
            derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
        }
          temp0 = theta0 - (alpha * derivative0);
          temp1 = theta1 - (alpha * derivative1);
          theta0 = temp0;
          theta1 = temp1;
          //System.out.println("Derivative0 = " + derivative0);
          //System.out.println("Derivative1 = " + derivative1);
        }
        while (derivative0 > 0.0001 || derivative1 > 0.0001);
        System.out.println();
        System.out.println("theta 0 = " + theta0);
        System.out.println("theta 1 = " + theta1);
    }
}

2 个答案:

答案 0 :(得分:2)

是的,它是凸的。

您使用的导数来自平方误差函数,它是凸的,因此不接受除一个全局最小值之外的局部最小值。 (事实上​​,这类问题甚至可以接受称为正规方程的封闭形式解决方案,对于大问题而言,它只是在数值上不易处理,因此使用梯度下降)

正确答案在theta0 = 0.4895theta1 = 0.1652左右,检查任何统计计算环境都是微不足道的。 (如果你持怀疑态度,请参见答案的底部)

下面我指出你的代码中的错误,在修正错误之后,你会得到4位小数以上的正确答案。

您的实施问题:

所以你希望它能够收敛全局最小值,但你在实现中遇到了问题

每次重新计算derivative_i时,您都忘记将其重置为0(您正在做的是在do{}while()

中的迭代中累积导数

你需要在do while循环中

do {                                                                    
   derivative0 = 0;                                                      
   derivative1 = 0;

   ...
}  

接下来就是这个

derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];

x1[i]因素应仅适用于(theta0 + (theta1 * x1[i]) - y[i]))

你的尝试有点令人困惑,所以让我们以更清晰的方式写下来,这与它的数学等式(1/m)sum(y_hat_i - y_i)x_i更接近:

// You need fresh vars, don't accumulate the derivatives across gradient descent iterations

derivative0 = 0;                                                      
derivative1 = 0;

for (int i = 0; i < m; i++) {                       
    derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);          
    derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i])*x1[i];    
}  

那应该让你足够接近,但是,我发现你的学习率alpha有点大。当它太大时,你的渐变下降会有归零没有你的全局最优,它会在那里徘徊,但不会完全存在。

double alpha = 0.5;                                                 

确认结果

运行它并将其与统计软件的答案进行比较

这是.java文件的gist on github

➜  ~ javac OneVariableRegression.java && java OneVariableRegression                                                                                                                           
20.0

theta 0 = 0.48950064086914064
theta 1 = 0.16520139788757973

我将它与R

进行了比较
> x
 [1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
 [7] -0.59160798 -0.42257713 -0.25354628 -0.08451543  0.08451543  0.25354628
[13]  0.42257713  0.59160798  0.76063883  0.92966968  1.09870053  1.26773138
[19]  1.43676223  1.60579308
> y
 [1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
[16] 0.61 0.65 0.68 0.74 0.87
> lm(y ~ x)

Call:
lm(formula = y ~ x)

Coefficients:
(Intercept)            x  
     0.4895       0.1652  

现在你的代码给出了至少4位小数的正确答案。

答案 1 :(得分:1)

是的,公式中存在错误。由于某种原因,你在乘法中包含了derivative0和1。这严重扭曲了结果。只需删除额外的括号,然后重试:

derivative0 = derivative0 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m);
derivative1 = derivative1 + (theta0 + (theta1 * x1[i]) - y[i]) * (1/m) * x1[i];

输出:

20.0
Derivative0 = 0.010499999999999995
Derivative1 = 0.31809711251208517
Derivative0 = 0.0052500000000000185
Derivative1 = 0.1829058398064968
Derivative0 = -0.007874999999999993
Derivative1 = -0.2129262545589219

theta 0 = 0.4881875
theta 1 = 0.06788495336050987

这更像你的预期吗?