The calculation gives me NaN

Time: 2014-10-08 19:50:03

Tags: matlab statistics regression logistic-regression

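I am trying to implement multinomial logistic regression using gradient descent, but my cost function starts returning NaN and the weights become NaN. Can someone tell me what I am doing wrong?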

function [ cost ] = costFunctionMultiNominal( inputX,resultY,weights )
%UNTITLED8 Calculates the cost for gradient descent,assumes inputX has one
%additional feature for constant and Weights is a classes X features matrix

[rows,cols] = size(inputX);
numOfClasses = size(weights,1);
summation = 0;
for i=1:rows
    classLevelSummation = 0;
    for j=1:numOfClasses
        if resultY(i)==j
            denominatorSum = 0;
            for l=1:numOfClasses
                denominatorSum = denominatorSum + exp((inputX(i,:)*weights(l,:)')-4444);
            end
            classLevelSummation = classLevelSummation + log(exp(inputX(i,:)*weights(j,:)'-4444)/denominatorSum); % <-- the debugger stops on this line
        end
    end
    summation = summation + classLevelSummation;
end
cost = summation/(-rows);
end

This is the weight update function:

function [ Weights ] = getWeightsUsingGradientDescentMultiNominal(trainingX,resultY,iterMax,Alpha,weight0,lambda )

%Returns updated weights through gradient descent,weight0 are the initial randomized weights
%   Detailed explanation goes here

rows = size(trainingX,1);
cols = size(trainingX,2)+1;
Weights = weight0;
numOfClasses = size(Weights,1);
%Adding one's to the input data for the constant terms
a = ones(rows,1);
X = [a trainingX];
%Each column corresponds to one weight, updating weights column wise:
%Also plot cost function simultaneously
tempCost = 0;
display(costFunctionMultiNominal(X,resultY,Weights));
plot(1,costFunctionMultiNominal(X,resultY,Weights),'r');
hold on;
for n=1:iterMax
    %Have to do this for all classes, i.e. rows in weights
    for j = 1:numOfClasses
        %First Calculating the Sigma over rows for all X
        summation = zeros(1,cols);
        for i=1:rows
            p = -1 * calculatePofJMultiNominal(X(i,:),Weights,j);
            if resultY(i) == j
                p = 1 + p;
            end 
            summation = summation + X(i,:)*p;

        end
       Weights(j,:) = Weights(j,:) - (Alpha)*(summation/(-rows) + lambda*Weights(j,:));
    end
    cost = costFunctionMultiNominal(X,resultY,Weights);
    display(cost);
    costDiff = tempCost - cost;
    if i~=0 && abs(costDiff)/cost <= 0.0001
        display('Breaking because of cost very less!');
        break;
    end
    tempCost = cost;

    plot(i,cost,'r');
end
hold off;
end

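As I understand it, the NaN appears because the exponents contain very large numbers. I tried subtracting a large constant (4444) from the exponent, but it did not help.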

I tried dbstop if NaN, and it tells me it stops on this line of the cost function (marked in the code above):

classLevelSummation = classLevelSummation +  log(exp(inputX(i,:)*weights(j,:)'-4444)/denominatorSum);
classLevelSummation still becomes NaN even if I remove the large constant -4444.

1 Answer:

Answer 0: (score: 0)

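In log( exp(blah) / denominator ), there is no need to exponentiate and then take the log: the two operations undo each other, and it is probably the exp() call that goes outside floating-point range. If you remember that log( A / B ) equals log(A) - log(B), you can rewrite this without the exp.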
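
A minimal sketch of that rewrite, not a drop-in fix for the function in the question. The values of x, W and j are made-up illustration data, and subtracting max(scores) inside the sum is the usual log-sum-exp shift, which is an addition here rather than something taken from the question's code.

```matlab
% Hypothetical data (not from the question): one sample, three classes.
x = [1 200 300];                  % one row of inputX, constant term first
W = [0.5 2 3; 0.1 1 2; 0.2 3 1];  % classes-by-features weight matrix
j = 2;                            % true class of this sample

scores = x * W';                  % 1-by-numOfClasses linear scores

% Naive form, as in the question: exp overflows for large scores.
naive = log(exp(scores(j)) / sum(exp(scores)));

% log(A/B) = log(A) - log(B), and log(exp(s)) = s, so the numerator needs no exp.
% Shifting by the largest score keeps every remaining exponent <= 0.
m = max(scores);
stable = scores(j) - (m + log(sum(exp(scores - m))));

disp(naive);
disp(stable);
```

With these numbers the naive expression is NaN while the rewritten one stays finite. The shift by m does not change the result mathematically, because it cancels between the log of the sum and the added m.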

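A simple overflow in exp would probably give you an Inf rather than a NaN. You should check the value of denominatorSum, since it is also a sum of exponential terms.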
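
A small check of the floating-point behavior involved, offered as a sketch with arbitrary example exponents. It suggests two ways the ratio can turn into NaN: the numerator and denominator both overflow to Inf, or, if the scores are small relative to the subtracted 4444, both underflow to 0.

```matlab
a = exp(1000);    % too large for a double: overflows to Inf
b = exp(-4444);   % too small for a double: underflows to 0

r1 = a / a;       % Inf/Inf is NaN
r2 = b / b;       % 0/0 is NaN

disp([isinf(a), b == 0, isnan(r1), isnan(r2)]);
```

Either case would make log(numerator/denominatorSum) evaluate to NaN, which is why inspecting denominatorSum at the breakpoint is a reasonable first step.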