EM算法代码不起作用

时间:2015-05-01 00:09:51

标签: algorithm matlab statistics

我正在尝试使用EM算法估计高斯混合的平均值,权重和协方差,但我得到了所有矢量值的“NaN”(非数字)。这是我的代码:

function [nw,ns,nu]=est(wo,so,uo,w,s,u,o,J)
% o is the data points vector (size=nx2)
% J is the number of Gaussian curve
% maximum number of Gaussians=5
% wo is the old weight vector (size=1x5) and "w" is the actual weight vector (size=1x5)
% so is the old covariance vector (size=1x5) and "s" is the actual covariance value (size=1x5)
% uo is the old mean vector (size=5x2) and "u" is the actual mean vector (size=5x2)

o=unique(o,'rows');
t=length(o);

for i=1:5
for j=1:J
    un=0;ud=0;s2n=0;
    for T=1:t
        ud=ud+probab(wo,o(T,:),uo,so,j,J);
    end
    for T=1:t
        un=un+probab(wo,o(T,:),uo,so,j,J)*(o(T,:));
    end
    for T=1:t
        s2n=s2n+probab(wo,o(T,:),uo,so,j,J)*(o(T,:)-uo(j,:))*(((o(T,:)-uo(j,:))'));
    end
    wtemp(j)=ud/t;
    utemp(j,:)=un/ud;
    stemp(j)=sqrt(s2n/(ud));
end
wo=wtemp;
uo=utemp;
so=stemp;
end
    nw=wo;
    ns=so;
    nu=uo;

“probab”功能是:

function [pro]=probab(w,o,u,s,j,J)
pro=w(j).*(1/(2*pi*s(j))).*exp((-1/(2.*(s(j).^2)))*((o(1)-u(j,1)).^2+(o(2)-u(j,2)).^2))/(gaussiandistribution(w,u,s,o(1),o(2),J));

“gaussiandistribution”功能是

function [z]=gaussiandistribution(w,u,s,X,Y,J)
z1=w(1).*(1/(2*pi*s(1))).*exp(-(((X-u(1,1)).^2)/(2.*s(1).^2)+((Y-u(1,2)).^2)/(2.*s(1).^2)));
z2=w(2).*(1/(2*pi*s(2))).*exp(-(((X-u(2,1)).^2)/(2.*s(2).^2)+((Y-u(2,2)).^2)/(2.*s(2).^2)));
z3=w(3).*(1/(2*pi*s(3))).*exp(-(((X-u(3,1)).^2)/(2.*s(3).^2)+((Y-u(3,2)).^2)/(2.*s(3).^2)));
z4=w(4).*(1/(2*pi*s(4))).*exp(-(((X-u(4,1)).^2)/(2.*s(4).^2)+((Y-u(4,2)).^2)/(2.*s(4).^2)));
z5=w(5).*(1/(2*pi*s(5))).*exp(-(((X-u(5,1)).^2)/(2.*s(5).^2)+((Y-u(5,2)).^2)/(2.*s(5).^2)));
if J==5
z=z1+z2+z3+z4+z5;
elseif J==4
z=z1+z2+z3+z4;
elseif J==3
z=z1+z2+z3;
elseif J==2
z=z1+z2;
elseif J==1
z=z1;
end

1 个答案:

答案 0 :(得分:0)

我已经运行了你的代码(在Octave中)并且我没有看到任何NaN。以下是我的价值观:

octave:13> wo
wo =

   0.20000   0.20000   0.20000   0.20000   0.20000

octave:14> so
so =

   1   1   1   1   1

octave:15> uo
uo =

   0   0
   0   0
   0   0
   0   0
   0   0

octave:16> w
w =

   0.20000   0.20000   0.20000   0.20000   0.20000

octave:17> s
s =

   1   1   1   1   1

octave:18> u
u =

   0   0
   0   0
   0   0
   0   0
   0   0

octave:19> 
octave:19> o
o =

   1.00000   2.00000
   3.00000   1.00000
  -1.00000   2.00000
   2.00000  -2.00000
  -2.00000  -1.00000
   3.00000  -2.00000
   2.00000   3.00000
   1.50000  -0.25000

这是我得到的输出:

octave:20> [nw, ns, nu] = est(wo, so, uo, w, s, u, o, 5)
nw =

   0.20000   0.20000   0.20000   0.20000   0.20000

ns =

   2.4770   2.4770   2.4770   2.4770   2.4770

nu =

   1.18750   0.34375
   1.18750   0.34375
   1.18750   0.34375
   1.18750   0.34375
   1.18750   0.34375

我没有检查这些是否正确,但无论如何它们都不是NaN。

您对输入有什么价值?