梯度下降不会收敛

时间:2014-12-19 23:59:27

标签: regression matlab gradient-descent

这是我自己在matlab语言中实现的梯度下降算法

 m = height(data_training); % number of samples
cols = {'x1', 'x2', 'x3', 'x4', 'x5', 'x6',...
    'x7', 'x8','x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15'}; 

y = data_training{:, {'y'}}';
X = [ones(m,1) data_training{:,cols}]'; 

theta = zeros(1,width(data_training));

alpha = 1e-2; % learning rate
iter = 400;

dJ = zeros(1,width(data_training));

J_seq = zeros(1, iter);

for n = 1:iter

    err = (theta*X - y);

    for j = 1:width(data_training)
        dJ(j) = 1/m*sum(err*X(j,:)');
    end

    J = 1/(2*m)*sum((theta*X-y).^2);

    theta = theta - alpha.*dJ;

    J_seq(n) = J;

    if mod(n,100) == 0
        plot(1:iter, J_seq);
    end
end

修改 工作算法

我已将此算法应用于以下训练数据集。最后一列是输出变量。这里我们有15个不同的功能。

由于我未知的原因,当我在50次迭代后绘制成本函数J以检查它是否正在收敛时,我发现它没有收敛。你能帮我理解吗?是实施错误还是我应该做些什么?

36    27    71     8.1    3.34    11.4    81.5    3243     8.8    42.6    11.7     21     15     59    59     921.87
35    23    72    11.1    3.14      11    78.8    4281     3.6    50.7    14.4      8     10     39    57     997.88
44    29    74    10.4    3.21     9.8    81.6    4260     0.8    39.4    12.4      6      6     33    54     962.35
47    45    79     6.5    3.41    11.1    77.5    3125    27.1    50.2    20.6     18      8     24    56     982.29
43    35    77     7.6    3.44     9.6    84.6    6441    24.4    43.7    14.3     43     38    206    55     1071.3
53    45    80     7.7    3.45    10.2    66.8    3325    38.5    43.1    25.5     30     32     72    54     1030.4
43    30    74    10.9    3.23    12.1    83.9    4679     3.5    49.2    11.3     21     32     62    56      934.7
45    30    73     9.3    3.29    10.6      86    2140     5.3    40.4    10.5      6      4      4    56     899.53
36    24    70       9    3.31    10.5    83.2    6582     8.1    42.5    12.6     18     12     37    61     1001.9
36    27    72     9.5    3.36    10.7    79.3    4213     6.7      41    13.2     12      7     20    59     912.35
52    42    79     7.7    3.39     9.6    69.2    2302    22.2    41.3    24.2     18      8     27    56     1017.6
33    26    76     8.6     3.2    10.9    83.4    6122    16.3    44.9    10.7     88     63    278    58     1024.9
40    34    77     9.2    3.21    10.2      77    4101      13    45.7    15.1     26     26    146    57     970.47
35    28    71     8.8    3.29    11.1    86.3    3042    14.7    44.6    11.4     31     21     64    60     985.95
37    31    75       8    3.26    11.9    78.4    4259    13.1    49.6    13.9     23      9     15    58     958.84
35    46    85     7.1    3.22    11.8    79.9    1441    14.8    51.2    16.1      1      1      1    54      860.1
36    30    75     7.5    3.35    11.4    81.9    4029    12.4      44      12      6      4     16    58     936.23
15    30    73     8.2    3.15    12.2    84.2    4824     4.7    53.1    12.7     17      8     28    38     871.77
31    27    74     7.2    3.44    10.8      87    4834    15.8    43.5    13.6     52     35    124    59     959.22
30    24    72     6.5    3.53    10.8    79.5    3694    13.1    33.8    12.4     11      4     11    61     941.18
31    45    85     7.3    3.22    11.4    80.7    1844    11.5    48.1    18.5      1      1      1    53     891.71
31    24    72       9    3.37    10.9    82.8    3226     5.1    45.2    12.3      5      3     10    61     871.34
42    40    77     6.1    3.45    10.4    71.8    2269    22.7    41.4    19.5      8      3      5    53     971.12
43    27    72       9    3.25    11.5    87.1    2909     7.2    51.6     9.5      7      3     10    56     887.47
46    55    84     5.6    3.35    11.4    79.7    2647      21    46.9    17.9      6      5      1    59     952.53
39    29    76     8.7    3.23    11.4    78.6    4412    15.6    46.6    13.2     13      7     33    60     968.66
35    31    81     9.2     3.1      12    78.3    3262    12.6    48.6    13.9      7      4      4    55     919.73
43    32    74    10.1    3.38     9.5    79.2    3214     2.9    43.7      12     11      7     32    54     844.05
11    53    68     9.2    2.99    12.1    90.6    4700     7.8    48.9    12.3    648    319    130    47     861.83
30    35    71     8.3    3.37     9.9    77.4    4474    13.1    42.6    17.7     38     37    193    57     989.26
50    42    82     7.3    3.49    10.4    72.5    3497    36.7    43.3    26.4     15     10     34    59     1006.5
60    67    82      10    2.98    11.5    88.6    4657    13.6    47.3    22.4      3      1      1    60     861.44
30    20    69     8.8    3.26    11.1    85.4    2934     5.8      44     9.4     33     23    125    64     929.15
25    12    73     9.2    3.28    12.1    83.1    2095       2    51.9     9.8     20     11     26    50     857.62
45    40    80     8.3    3.32    10.1    70.3    2682      21    46.1    24.1     17     14     78    56     961.01
46    30    72    10.2    3.16    11.3    83.2    3327     8.8    45.3    12.2      4      3      8    58     923.23

1 个答案:

答案 0 :(得分:2)

不确定我是否遵循了你的逻辑,但很明显'e'(错误)应该被平方。

让我们看看你应该使用什么。

theta是未知数的列向量,y是度量的列向量,X是模型矩阵,其中每一行都是“示例”。所以你需要找到theta这样:

y = X*theta 

或等效地,使用优化方法查找theta最小化当前平方误差(这是使凸起优化问题的原因):

e[n] = (y - X*theta[n])

e[n]^2 --> minimize 

渐变下降使用误差函数的梯度(相对于theta)来更新theta向量:

theta[n+1] = theta[n] - alpha*2*X'*e[n]

(注意e [n]和theta [n]是向量。这是数学符号 - 不是matlab的)

所以你看到e [n]在更新方程中没有平方。