我经历了几次安德鲁教授的机器学习课程,并使用牛顿法查看了Logistic回归的成绩单。然而,当使用梯度下降实施逻辑回归时,我面临某些问题。
生成的图形不是凸的。
我的代码如下:
我正在使用等式的矢量化实现。
%1. The below code would load the data present in your desktop to the octave memory
x=load('ex4x.dat');
y=load('ex4y.dat');
%2. Now we want to add a column x0 with all the rows as value 1 into the matrix.
%First take the length
m=length(y);
x=[ones(m,1),x];
alpha=0.1;
max_iter=100;
g=inline('1.0 ./ (1.0 + exp(-z))');
theta = zeros(size(x(1,:)))'; % the theta has to be a 3*1 matrix so that it can multiply by x that is m*3 matrix
j=zeros(max_iter,1); % j is a zero matrix that is used to store the theta cost function j(theta)
for num_iter=1:max_iter
% Now we calculate the hx or hypothetis, It is calculated here inside no. of iteration because the hupothesis has to be calculated for new theta for every iteration
z=x*theta;
h=g(z); % Here the effect of inline function we used earlier will reflect
j(num_iter)=(1/m)*(-y'* log(h) - (1 - y)'*log(1-h)) ; % This formula is the vectorized form of the cost function J(theta) This calculates the cost function
j
grad=(1/m) * x' * (h-y); % This formula is the gradient descent formula that calculates the theta value.
theta=theta - alpha .* grad; % Actual Calculation for theta
theta
end
每个说法的代码不会产生任何错误,但不会产生正确的凸图。
如果任何机构能够指出错误或分享导致问题的原因,我将感到高兴。
感谢
答案 0 :(得分:1)
您需要研究的两件事:
这是我的代码,给出了凸图
clc; clear; close all;
load q1x.dat;
load q1y.dat;
X = [ones(size(q1x, 1),1) q1x];
Y = q1y;
m = size(X,1);
n = size(X,2)-1;
%initialize
theta = zeros(n+1,1);
thetaold = ones(n+1,1);
while ( ((theta-thetaold)'*(theta-thetaold)) > 0.0000001 )
%calculate dellltheta
dellltheta = zeros(n+1,1);
for j=1:n+1,
for i=1:m,
dellltheta(j,1) = dellltheta(j,1) + [Y(i,1) - (1/(1 + exp(-theta'*X(i,:)')))]*X(i,j);
end;
end;
%calculate hessian
H = zeros(n+1, n+1);
for j=1:n+1,
for k=1:n+1,
for i=1:m,
H(j,k) = H(j,k) -[1/(1 + exp(-theta'*X(i,:)'))]*[1-(1/(1 + exp(-theta'*X(i,:)')))]*[X(i,j)]*[X(i,k)];
end;
end;
end;
thetaold = theta;
theta = theta - inv(H)*dellltheta;
(theta-thetaold)'*(theta-thetaold)
end
迭代后我得到以下错误值:
2.8553
0.6596
0.1532
0.0057
5.9152e-06
6.1469e-12
当绘制时看起来像: