实现"梯度下降算法"在Matlab中

时间:2015-04-11 21:05:27

标签: algorithm matlab

我正在Machine Learning课程中解决编程作业。我将在其中实施Gradient Descent Algorithm,如下所示

enter image description here

我在Matlab中使用以下代码

data = load('ex1data1.txt');
% text file conatins 2 values in each row separated by commas
X = [ones(m, 1), data(:,1)];
theta = zeros(2, 1);
iterations = 1500;
alpha = 0.01;

function [theta, J_history] = gradientDescent(X, y, theta, alpha, num_iters)
m = length(y); % number of training examples
J_history = zeros(num_iters, 1);
for iter = 1:num_iters
   k=1:m;
   j1=(1/m)*sum((theta(1)+theta(2).*X(k,2))-y(k))
   j2=((1/m)*sum((theta(1)+theta(2).*X(k,2))-y(k)))*(X(k,2))
   theta(1)=theta(1)-alpha*(j1);
   theta(2)=theta(2)-alpha*(j2);
   J_history(iter) = computeCost(X, y, theta);
end
end

theta = gradientDescent(X, y, theta, alpha, iterations);

在运行上面的代码时,我收到此错误消息

enter image description here

从错误信息中可以清楚地看出,以下表达式的结果为

((1/m)*sum((theta(1)+theta(2).*X(k,2))-y(k)))*(X(k,2))

是一个向量,我们试图将其保存在标量变量j2中。我认为X(k,2)正在创建一个问题,我已经将它用作向量X的索引,以从第k行和第2列获取值。但另一方面,整个矢量越来越多,请建议我如何解决它。

1 个答案:

答案 0 :(得分:2)

您应该学会阅读错误消息,并按照那里的潜在客户进行操作:

  • 如错误消息中所述,行theta(2)=theta(2)-alpha*(j2);的左侧和右侧的元素数量不同,因此请尝试找出它中的哪一个是。标准技巧是对行上表达式的所有不同术语执行disp(size(...)),然后检查所有内容是否具有您期望的大小。

  • 进一步推理:theta(2)alpha似乎是标量,因此j2可能是非标量。

  • 查看j2的定义,似乎sum(...)是标量,而最终(X(k,2))是大小(m,1)的向量,所以j2也是大小(m,1),而它应该是标量。错误可能是您需要将X(k,2)部分包含在总和中,以便最终结果是标量。

其他一些观察结果:

  • 您创建X所有的第一列,而后来只使用第二列。更容易使用x = data(:,1)并使用它。

  • 您执行m=length(y); k=1:m;,然后多次使用y(k)。更容易使用y本身...

您可能需要执行类似

的操作
theta(1)=theta(1) - alpha / m * sum(whatever - y);
theta(2)=theta(2) - alpha / m * sum((whatever - y) .* x);

你需要自己弄清楚其余的......