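I am trying to implement multinomial logistic regression using gradient descent, but my cost function starts returning NaN values for the weights. Can someone tell me what I am doing wrong?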
```matlab
function [ cost ] = costFunctionMultiNominal( inputX,resultY,weights )
%COSTFUNCTIONMULTINOMINAL Calculates the cost for gradient descent; assumes inputX has one
%additional feature for the constant term and weights is a (classes x features) matrix
[rows,cols] = size(inputX);
numOfClasses = size(weights,1);
summation = 0;
for i=1:rows
    classLevelSummation = 0;
    for j=1:numOfClasses
        if resultY(i)==j
            denominatorSum = 0;
            for l=1:numOfClasses
                denominatorSum = denominatorSum + exp((inputX(i,:)*weights(l,:)')-4444);
            end
            % dbstop if naninf stops on the next line:
            classLevelSummation = classLevelSummation + log(exp(inputX(i,:)*weights(j,:)'-4444)/denominatorSum);
        end
    end
    summation = summation + classLevelSummation;
end
cost = summation/(-rows);
end
```
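Written out as a formula, the loops above compute the cost below. The notation is introduced here only for illustration: m is `rows`, K is `numOfClasses`, x_i is the i-th row of `inputX` (treated as a column vector), w_j is the j-th row of `weights`, y_i is `resultY(i)`, and c = 4444 is the constant the code subtracts inside every `exp`.

$$
J(W) = -\frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{K}\mathbf{1}\{y_i=j\}\,
\log\frac{\exp(w_j^{\top}x_i - c)}{\sum_{l=1}^{K}\exp(w_l^{\top}x_i - c)}
= -\frac{1}{m}\sum_{i=1}^{m}\Bigl(w_{y_i}^{\top}x_i - \log\sum_{l=1}^{K}\exp(w_l^{\top}x_i)\Bigr)
$$

Because c cancels exactly, subtracting it does not change the value in exact arithmetic; it only changes which terms overflow or underflow in floating point.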
Here is the weight update function:
```matlab
function [ Weights ] = getWeightsUsingGradientDescentMultiNominal(trainingX,resultY,iterMax,Alpha,weight0,lambda )
%Returns updated weights through gradient descent; weight0 are the initial randomized weights
rows = size(trainingX,1);
cols = size(trainingX,2)+1;
Weights = weight0;
numOfClasses = size(Weights,1);
%Adding ones to the input data for the constant terms
a = ones(rows,1);
X = [a trainingX];
%Each column corresponds to one weight, updating weights column-wise.
%Also plot the cost function simultaneously
tempCost = 0;
display(costFunctionMultiNominal(X,resultY,Weights));
plot(1,costFunctionMultiNominal(X,resultY,Weights),'r');
hold on;
for n=1:iterMax
    %Have to do this for all classes, i.e. rows in Weights
    for j = 1:numOfClasses
        %First calculating the sum over rows for all X
        summation = zeros(1,cols);
        for i=1:rows
            p = -1 * calculatePofJMultiNominal(X(i,:),Weights,j);
            if resultY(i) == j
                p = 1 + p;
            end
            summation = summation + X(i,:)*p;
        end
        Weights(j,:) = Weights(j,:) - (Alpha)*(summation/(-rows) + lambda*Weights(j,:));
    end
    cost = costFunctionMultiNominal(X,resultY,Weights);
    display(cost);
    costDiff = tempCost - cost;
    if i~=0 && abs(costDiff)/cost <= 0.0001
        display('Breaking because of cost very less!');
        break;
    end
    tempCost = cost;
    plot(i,cost,'r');
end
hold off;
end
```
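For comparison, here is a hedged, vectorized sketch of one update step. It assumes X already includes the column of ones, W is K-by-d, and y holds labels 1..K; the toy data are made up. Because `calculatePofJMultiNominal` is not shown in the question, the probabilities here come from a row-max-shifted softmax that stands in for it rather than reproducing it. Unlike the loop above, which updates `Weights(j,:)` one class at a time (so later classes see already-updated rows), this sketch updates all class rows simultaneously.

```matlab
% One vectorized gradient-descent step for regularized multinomial
% logistic regression (sketch). Assumptions: X is m-by-d with a leading
% column of ones, W is K-by-d, y holds labels 1..K; alpha and lambda play
% the roles of Alpha and lambda in the question's code. Toy data only.
X = [1  0.5 -1.2;
     1  2.0  0.3;
     1 -0.7  1.5];
y = [1; 3; 2];
K = 3;
W = zeros(K, size(X, 2));
alpha = 0.1;
lambda = 0.01;
m = size(X, 1);

Z = X * W';                               % m-by-K class scores
Z = bsxfun(@minus, Z, max(Z, [], 2));     % shift each row; softmax is unchanged
P = exp(Z);
P = bsxfun(@rdivide, P, sum(P, 2));       % P(i,j) = probability of class j for sample i

Y = zeros(m, K);                          % one-hot targets: Y(i,j) = 1 if y(i) == j
Y(sub2ind(size(Y), (1:m)', y)) = 1;

% Same gradient as the loop: -(1/m) * sum_i (1{y_i = j} - p_ij) * x_i + lambda * w_j
grad = -(Y - P)' * X / m + lambda * W;    % K-by-d
W = W - alpha * grad;                     % simultaneous update of all class rows

disp(W)
```

Since every row of `Y - P` sums to zero, the columns of `grad` sum to `lambda * sum(W, 1)` across classes, which is a quick sanity check when debugging.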
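As I understand it, the NaN comes from very large numbers in the exponent. I tried subtracting a large constant (4444) from the exponent, but it didn't help.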
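I tried `dbstop if naninf`, and it stops at this line of the cost function (the highlighted line in the code above):

```matlab
classLevelSummation = classLevelSummation + log(exp(inputX(i,:)*weights(j,:)'-4444)/denominatorSum);
```

The same thing happens even if I remove the large constant -4444.

Answer: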
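In `log( exp(blah) / denominator )` there is no need to exponentiate and then take the log: the two operations undo each other, and it is probably the `exp()` call that is going outside the floating-point range. You can rewrite this without the `exp` if you remember that `log( A / B )` equals `log(A) - log(B)`; since `log(exp(blah))` is just `blah`, the numerator term reduces to the exponent itself.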
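Following that suggestion, the whole cost can be computed in the log domain. This is a minimal sketch, assuming X already contains the column of ones, W is K-by-d, and y holds integer labels 1..K; the data are hypothetical. It also subtracts each row's maximum score before exponentiating (the standard log-sum-exp trick, which goes beyond the answer's suggestion), because `log(denominatorSum)` would otherwise still overflow.

```matlab
% Numerically stable multinomial logistic cost (sketch).
% Assumptions: X is m-by-d with a leading column of ones,
% W is K-by-d (one row per class), y is m-by-1 with labels 1..K.

% Row-wise log(sum(exp(Z), 2)) without overflow: shift each row by its
% maximum so the largest exponent evaluated is exp(0) = 1.
logSumExp = @(Z) max(Z, [], 2) + ...
    log(sum(exp(bsxfun(@minus, Z, max(Z, [], 2))), 2));

% Hypothetical data whose scores would overflow a naive exp().
X = [1  500 -200;
     1 -300  800];
W = [0.0  2.0  1.0;
     0.5 -1.0  3.0];
y = [2; 2];                                 % both samples labelled class 2

Z = X * W';                                 % m-by-K class scores
m = size(X, 1);
idx = sub2ind(size(Z), (1:m)', y);          % linear index of each true-class score
logP = Z(idx) - logSumExp(Z);               % log(A/B) = log(A) - log(B)
cost = -mean(logP);                         % finite: 949.75 for this data
disp(cost)
```

With these scores a direct `exp` of 800 or 2700.5 is `Inf`, yet every intermediate value above stays finite.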
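A simple overflow in `exp` would probably give you an `Inf` rather than a `NaN`. You should check the value of `denominatorSum`, since that is also an exponential term: if the numerator and `denominatorSum` both overflow, `Inf/Inf` is `NaN`, and if both underflow to zero (which subtracting 4444 can cause), `0/0` is also `NaN`.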
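The two failure modes can be reproduced directly. This sketch uses hypothetical scores for a single sample; in double precision, `exp` overflows to `Inf` for arguments above roughly 709 and underflows to 0 below roughly -745.

```matlab
% How NaN (not just Inf) appears in log(exp(a) / denominatorSum).
z = [800, 200];                      % hypothetical class scores for one sample

% Overflow: numerator and denominator are both Inf, and Inf/Inf is NaN.
numOver = exp(z(1));
denOver = sum(exp(z));
naiveOver = log(numOver / denOver);

% Subtracting a fixed constant (like the question's 4444) goes the other way:
% both terms underflow to 0, and 0/0 is also NaN.
numUnder = exp(z(1) - 4444);
denUnder = sum(exp(z - 4444));
naiveUnder = log(numUnder / denUnder);

% Log-domain version, shifting by the maximum score instead, stays finite.
zMax = max(z);
stable = z(1) - (zMax + log(sum(exp(z - zMax))));

disp([denOver, denUnder])
disp([naiveOver, naiveUnder, stable])
```

Shifting by the maximum score rather than a fixed constant guarantees the largest exponent is exactly 0, so the sum is at least 1 and neither overflow nor `0/0` can occur.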