我正在尝试使用EM算法估计高斯混合的平均值,权重和协方差,但我得到了所有矢量值的“NaN”(非数字)。这是我的代码:
function [nw,ns,nu]=est(wo,so,uo,w,s,u,o,J)
% o is the data points vector (size=nx2)
% J is the number of Gaussian curve
% maximum number of Gaussians=5
% wo is the old weight vector (size=1x5) and "w" is the actual weight vector (size=1x5)
% so is the old covariance vector (size=1x5) and "s" is the actual covariance value (size=1x5)
% uo is the old mean vector (size=5x2) and "u" is the actual mean vector (size=5x2)
o=unique(o,'rows');
t=length(o);
for i=1:5
for j=1:J
un=0;ud=0;s2n=0;
for T=1:t
ud=ud+probab(wo,o(T,:),uo,so,j,J);
end
for T=1:t
un=un+probab(wo,o(T,:),uo,so,j,J)*(o(T,:));
end
for T=1:t
s2n=s2n+probab(wo,o(T,:),uo,so,j,J)*(o(T,:)-uo(j,:))*(((o(T,:)-uo(j,:))'));
end
wtemp(j)=ud/t;
utemp(j,:)=un/ud;
stemp(j)=sqrt(s2n/(ud));
end
wo=wtemp;
uo=utemp;
so=stemp;
end
nw=wo;
ns=so;
nu=uo;
“probab”功能是:
function [pro]=probab(w,o,u,s,j,J)
pro=w(j).*(1/(2*pi*s(j))).*exp((-1/(2.*(s(j).^2)))*((o(1)-u(j,1)).^2+(o(2)-u(j,2)).^2))/(gaussiandistribution(w,u,s,o(1),o(2),J));
“gaussiandistribution”功能是
function [z]=gaussiandistribution(w,u,s,X,Y,J)
z1=w(1).*(1/(2*pi*s(1))).*exp(-(((X-u(1,1)).^2)/(2.*s(1).^2)+((Y-u(1,2)).^2)/(2.*s(1).^2)));
z2=w(2).*(1/(2*pi*s(2))).*exp(-(((X-u(2,1)).^2)/(2.*s(2).^2)+((Y-u(2,2)).^2)/(2.*s(2).^2)));
z3=w(3).*(1/(2*pi*s(3))).*exp(-(((X-u(3,1)).^2)/(2.*s(3).^2)+((Y-u(3,2)).^2)/(2.*s(3).^2)));
z4=w(4).*(1/(2*pi*s(4))).*exp(-(((X-u(4,1)).^2)/(2.*s(4).^2)+((Y-u(4,2)).^2)/(2.*s(4).^2)));
z5=w(5).*(1/(2*pi*s(5))).*exp(-(((X-u(5,1)).^2)/(2.*s(5).^2)+((Y-u(5,2)).^2)/(2.*s(5).^2)));
if J==5
z=z1+z2+z3+z4+z5;
elseif J==4
z=z1+z2+z3+z4;
elseif J==3
z=z1+z2+z3;
elseif J==2
z=z1+z2;
elseif J==1
z=z1;
end
答案 0 :(得分:0)
我已经运行了你的代码(在Octave中)并且我没有看到任何NaN。以下是我的价值观:
octave:13> wo
wo =
0.20000 0.20000 0.20000 0.20000 0.20000
octave:14> so
so =
1 1 1 1 1
octave:15> uo
uo =
0 0
0 0
0 0
0 0
0 0
octave:16> w
w =
0.20000 0.20000 0.20000 0.20000 0.20000
octave:17> s
s =
1 1 1 1 1
octave:18> u
u =
0 0
0 0
0 0
0 0
0 0
octave:19>
octave:19> o
o =
1.00000 2.00000
3.00000 1.00000
-1.00000 2.00000
2.00000 -2.00000
-2.00000 -1.00000
3.00000 -2.00000
2.00000 3.00000
1.50000 -0.25000
这是我得到的输出:
octave:20> [nw, ns, nu] = est(wo, so, uo, w, s, u, o, 5)
nw =
0.20000 0.20000 0.20000 0.20000 0.20000
ns =
2.4770 2.4770 2.4770 2.4770 2.4770
nu =
1.18750 0.34375
1.18750 0.34375
1.18750 0.34375
1.18750 0.34375
1.18750 0.34375
我没有检查这些是否正确,但无论如何它们都不是NaN。
您对输入有什么价值?