线性回归中的梯度下降出错了

时间:2015-03-09 12:53:29

标签: matlab machine-learning gradient-descent

我实际上想要使用线性模型拟合一组'sin'数据,但事实证明在每次迭代期间损失函数变大。我的代码下面有什么问题吗? (梯度下降法)

这是我在Matlab中的代码

m=20;
rate = 0.1;
x = linspace(0,2*pi,20);
x = [ones(1,length(x));x]
y = sin(x);
w = rand(1,2);
for i=1:500
    h = w*x;
    loss = sum((h-y).^2)/m/2 
    total_loss = [total_loss loss];
    **gradient = (h-y)*x'./m ;**
    w = w - rate.*gradient;
end

这是我想要的数据 y=sin(x)

1 个答案:

答案 0 :(得分:1)

您的代码没有问题。使用当前框架,如果您可以以y = m*x + b的形式定义数据,那么此代码就足够了。我实际上通过几个测试运行它,在那里我定义了一个线的方程并向其添加一些高斯随机噪声(幅度= 0.1,平均值= 0,标准偏差= 1)。

但是,我要提到的一个问题是,如果您查看正弦数据,可以在[0,2*pi]之间定义一个域。如您所见,您有多个x值,这些值会映射到相同的y值,但幅度不同。例如,在x = pi/2我们得到1但在x = -3*pi/2得到-1。这种高变异性对于线性回归来说不是好兆头,所以我的一个建议就是限制你的域......所以像[0, pi]这样。它可能不会收敛的另一个原因是你选择的学习率太高了。我把它设置为0.01之类的低点。正如你在评论中提到的那样,你已经想出来了!

但是,如果您想使用线性回归拟合非线性数据,则必须包含更高阶的术语以说明可变性。因此,尝试包括二阶和/或三阶项。这可以通过修改x矩阵来完成:

x = [ones(1,length(x)); x; x.^2; x.^3];

如果你还记得,假设函数可以表示为线性项的总和:

h(x) = theta0 + theta1*x1 + theta2*x2 + ... + thetan*xn

在我们的例子中,每个theta项将构建我们多项式的高阶项。 x2x^2x3x^3。因此,我们仍然可以在这里使用梯度下降的定义进行线性回归。

我也将控制随机生成种子(通过rng),以便您可以产生我得到的相同结果:

clear all; 
close all;
rng(123123);
total_loss = [];
m = 20;
x = linspace(0,pi,m); %// Change
y = sin(x);
w = rand(1,4); %// Change
rate = 0.01; %// Change
x = [ones(1,length(x)); x; x.^2; x.^3]; %// Change - Second and third order terms
for i=1:500
    h = w*x;
    loss = sum((h-y).^2)/m/2;
    total_loss = [total_loss loss];
    % gradient is now in a different expression
    gradient = (h-y)*x'./m ; % sum all in each iteration, it's a batch gradient
    w = w - rate.*gradient;
end

如果我们尝试这样做,我们会获得w(您的参数):

>> format long g;
>> w


w =

  Columns 1 through 3

         0.128369521905694         0.819533906064327       -0.0944622478526915

  Column 4

       -0.0596638117151464

此后我的最终损失是:

loss =

       0.00154350916582836

这意味着我们的线的等式是:

y = 0.12 + 0.819x - 0.094x^2 - 0.059x^3

如果我们用正弦数据绘制线的这个等式,这就是我们得到的:

xval = x(2,:);
plot(xval, y, xval, polyval(fliplr(w), xval))
legend('Original', 'Fitted');

enter image description here