Calculating the covariance of different classes

Time: 2012-09-02 20:17:38

Tags: matlab

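I am trying to compute the mean and covariance matrix for a given set of observations. The point list is a 3-D array: the first dimension is the class number, the second is the observation number, and the third is the coordinate number. I have been able to compute the mean, but something seems to be wrong with the covariance (right now I get a matrix of zeros). I would appreciate it if someone could tell me how to correct it.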

function [ meanEst, covEst, priorProbEst, classMem ] = estimateParams( trainingSet, classList )
%estimateParams estimate all parameters for each class

numRows = size(trainingSet, 1);
numClasses = max(classList.');
%pointList = zeros(numClasses, numRows, 2);
classMem = zeros(numClasses, 1);

for rowCtr = 1:numRows
    curClass = classList(rowCtr, 1);
    classMem(curClass) = classMem(curClass) + 1;
    pointList(curClass, classMem(curClass), 1) = trainingSet(rowCtr, 1);
    pointList(curClass, classMem(curClass), 2) = trainingSet(rowCtr, 2);
end

meanEst      = zeros(numClasses, 2);
covEst       = zeros(numClasses, 2, 2);
priorProbEst = zeros(numClasses, 1);
tot          = zeros(numClasses, 2);

for classCtr = 1:numClasses
    for pointCtr = 1:classMem(classCtr)
        tot(classCtr, 1) = tot(classCtr, 1) + pointList(classCtr, pointCtr, 1);
        tot(classCtr, 2) = tot(classCtr, 2) + pointList(classCtr, pointCtr, 2);
    end
    meanEst(classCtr, 1) = tot(classCtr, 1) / classMem(classCtr);
    meanEst(classCtr, 2) = tot(classCtr, 2) / classMem(classCtr);

    covEst(classCtr) = cov(pointList(classCtr));
    priorProbEst(classCtr) = classMem(classCtr) / numRows;
end
end

Thanks for taking the time to look at this!

1 answer:

Answer 0: (score: 1)

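I think you are complicating things by introducing the 3-D pointList matrix. You can keep it if you prefer, but the error in the covariance estimate is somewhere in how that array is used.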
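One likely reading of the all-zeros result, sketched here with made-up data: `pointList(classCtr)` uses a single linear index, so it returns one scalar element, and `cov` of a scalar is 0. The assignment `covEst(classCtr) = ...` likewise writes only one element of the array. If the 3-D layout is kept, slicing out the valid points of a class as an n-by-2 matrix gives a proper 2x2 covariance. The array sizes, the random data, and the names `badCov`, `goodCov`, and `pts` are assumptions for illustration only.

```matlab
% Illustrative data (assumed): 2 classes, up to 5 points each, 2 coordinates
rng(0);
pointList = randn(2, 5, 2);
classMem  = [5; 3];          % number of valid points stored per class
classCtr  = 2;

% Call from the question: one linear index yields a scalar, and cov(scalar) is 0
badCov = cov(pointList(classCtr));

% Take only the valid points of this class and reshape the 1-by-n-by-2 slice
n   = classMem(classCtr);
pts = reshape(pointList(classCtr, 1:n, :), n, 2);   % n-by-2: rows are observations
goodCov = cov(pts);                                  % 2-by-2 covariance matrix

% Store the whole 2x2 block for this class, not a single element
covEst = zeros(2, 2, 2);
covEst(:, :, classCtr) = goodCov;
```

Restricting the slice to `1:n` matters because the array is padded with zeros for classes that have fewer points than the largest class; including the padding would bias both the mean and the covariance.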

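There is no reason to keep the data in such a structure, because every observation already has a class ID (that is, each row of trainingSet has a label in the corresponding row of classList). You can therefore always use logical indexing into trainingSet to retrieve the data needed to estimate the mean and cov. In general, keeping the data as an N x M = observations x variables matrix for any estimation or classification task is a convention that always helps, and it is consistent with many MATLAB functions.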

For example, below I create a random training set (an NxM matrix) and a list of label indices (an Nx1 list with K = 4 classes), then estimate the mean and covariance of each class, assigning the results to Kx2 and 2x2xK matrices respectively.

nPoints = 200;  % training set points
nClass  = 4;    % number of unique classes

% random training set of size nPoints x 2 (coordinates)
classList   = randi(nClass, nPoints, 1);
trainingSet = randn(nPoints, 2);

meanEst = zeros(nClass, 2);
covEst  = zeros(2, 2, nClass);
for classID = 1:nClass
    meanEst(classID,:)  = mean(trainingSet(classList==classID,:));
    covEst(:,:,classID) = cov(trainingSet(classList==classID,:));
end

As a check, running this code will produce the same results as your example above.
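The same loop can be extended to fill the other two outputs named in the question's function signature, `priorProbEst` and `classMem`. This is a minimal sketch, not the answerer's code: it assumes every class has at least two observations (otherwise `cov` does not return a 2x2 matrix), and the variable `members` is introduced only for illustration.

```matlab
nPoints = 200;  % training set points (assumed, as in the example above)
nClass  = 4;    % number of unique classes

classList   = randi(nClass, nPoints, 1);
trainingSet = randn(nPoints, 2);

meanEst      = zeros(nClass, 2);
covEst       = zeros(2, 2, nClass);
priorProbEst = zeros(nClass, 1);
classMem     = zeros(nClass, 1);

for classID = 1:nClass
    % rows of trainingSet whose label equals classID
    members = trainingSet(classList == classID, :);

    classMem(classID)     = size(members, 1);            % observations in this class
    meanEst(classID, :)   = mean(members, 1);            % 1-by-2 mean
    covEst(:, :, classID) = cov(members);                % 2-by-2 covariance
    priorProbEst(classID) = classMem(classID) / nPoints; % relative class frequency
end
```

Because every row belongs to exactly one class, the counts in `classMem` add up to `nPoints` and the priors add up to 1, which makes a quick sanity check.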