我正在使用MATLAB计算Logistic回归成本函数,并且在测试代码时没有得到预期的输出。我的步骤在逻辑上似乎都是正确的。谁能帮我解决这个问题吗?
function [J, grad] = costFunction(theta, X, y)
m = length(y);
J = 0;
grad = zeros(size(theta));
h = sigmoid(X*theta);
total1 = -y'*log(h);
total2 = (1-y)'* (1-log(h));
total3 = total1-total2;
J = total3*1/m;
这些是测试用例:
X = [ones(3,1) magic(3)];
y = [1 0 1]';
theta = [-2 -1 1 2]';
% un-regularized
[j g] = costFunction(theta, X, y)
% or...
[j g] = costFunctionReg(theta, X, y, 0)
% results
j = 4.6832
g =
0.31722
0.87232
1.64812
2.23787
我得到j = -0.3168而不是4.6832