使用球面GMM的EM算法的OpenCV实现问题

时间:2016-03-24 16:54:35

标签: c++ opencv machine-learning numerical-methods

我一直在尝试使用 OpenCV 2.4.10(VS2013),基于 EM 类中的 trainE 函数创建一个小型 GMM(100 个随机生成的二维样本,包含 2 个簇)。

我想要这样做的主要原因是在 MATLAB 中验证相同的实现。对于通用协方差的情况,结果匹配。但是它们与球形协方差不匹配。

所以我用 C++ 编写了一段代码,它使用 EM 类输出每个样本的对数似然(loglikelihoods):

#include <opencv2/opencv.hpp>

// Experiment configuration. Typed constants are used instead of preprocessor
// macros so that short names such as `Dim` or `maxiter` cannot silently
// rewrite unrelated identifiers later in the translation unit.
const int NumObs = 100;      // number of sample rows in the data matrix
const int Dim = 2;           // number of columns (features) per sample
const int numClusters = 2;   // number of mixture components given to cv::EM
const int maxiter = 1;       // iteration cap passed to cv::TermCriteria

int main(int argc, char** argv)
{   


    cv::Mat X = cv::Mat(NumObs, Dim, CV_64F);
    cv::Mat mean = cv::Mat(numClusters,Dim,CV_64F);
    std::vector<cv::Mat> covar;
    cv::Mat mixfrac = cv::Mat(numClusters, 1, CV_64F);
    cv::Mat logLikelihoods = cv::Mat(NumObs, 1, CV_64F);
    cv::Mat labels = cv::Mat(NumObs, 1, CV_32SC1);
    cv::Mat probs = cv::Mat(NumObs, numClusters, CV_64F);

    int i;
    mean = (cv::Mat_<double>(numClusters, Dim) << 2, 3, 4, 5);

    mixfrac = (cv::Mat_<double>(numClusters, 1) << 0.5, 0.5);

    cv::Mat temp(2, 2, CV_64F);
    temp = (cv::Mat_<double>(Dim, Dim) << 2, 0, 0, 2);
    covar.push_back(temp);
    temp = (cv::Mat_<double>(Dim, Dim) << 3, 0, 0, 3);
    covar.push_back(temp);

    X = (cv::Mat_<double>(NumObs, Dim) << 1.497295712046560, 2.389842806384741
        , -1.152388866819642, 2.324051537951996
        , 0.173091890138541, 3.534436744900519
        , 4.013635589903881, 2.635470732249428
        , 0.343404593485856, 2.710797002553179
        , 1.921946750771822, 2.230097059251492
        , 2.923105482852674, 1.968687377364361
        , 0.833763816549866, 1.621119740290876
        , 1.000974973520235, 0.501003642054062
        , 2.722581031828046, 2.179151938477484
        , 2.061657164408961, 1.870114601526011
        , 2.020337185029648, 1.822954976448330
        , 1.076363394410326, 2.640385082400831
        , 1.543478612011669, 2.241597508998520
        , 1.485372235285237, 2.893724906106137
        , -0.715927044291090, 1.369537907446789
        , 1.864588596758692, 1.688875700397673
        , -0.335577615523311, 2.183975485869516
        , -1.456483320753825, 1.493965172927608
        , -0.712160818106508, 1.912240385882839
        , 0.233365021273562, 1.954626227148290
        , 1.998344873833923, 0.343780671699337
        , 0.880159284239810, 0.926396318575558
        , 0.800239863676473, 2.841303191042556
        , 2.509835594907550, 0.767610261936824
        , -0.806015998003402, 2.440467562347095
        , 2.672846813458981, 1.899891020249070
        , 2.396696676145479, 3.621937374575406
        , 0.151374345266610, 2.569943677660735
        , 2.775466044119043, 1.781063533049695
        , 1.191366838548084, 1.702684062103669
        , -0.232621783623176, 1.631979353155736
        , -1.174537262824991, 1.738099976739696
        , -2.550748638571660, 1.816963733157464
        , 0.039014544632652, 1.808195242782888
        , -2.078943086646864, 2.084038835619226
        , 1.081089945552743, 2.576115695603576
        , -0.916439364072319, 0.708922705869511
        , 0.903297610096863, 2.791378047921652
        , -0.561947988394385, 2.656510856158910
        , 1.216211768410659, 1.219031885881044
        , 0.693337555617687, 1.359035671153648
        , 1.417702588339615, 2.313005179476850
        , 0.430424593609415, 1.132064410966174
        , -0.261023655491497, 1.981218019474763
        , 2.268819458983812, 0.391842019202139
        , 0.806386480042758, 3.199864590240754
        , -0.985989695126394, 0.893059534741550
        , -2.442654072043812, 1.689554199357084
        , 1.210751074224775, 1.499372753495019
        , -0.672567018407857, 2.970910199573346
        , -0.507381574772955, 1.967926272242394
        , 0.247969979069518, 0.947965104732303
        , 2.891713920836473, 0.861177994177589
        , 3.354602636199656, 1.942017768571913
        , -1.693061842976859, 1.760474923494942
        , 3.997953758133590, 1.297952400955861
        , 1.512000798331720, 0.853595479991537
        , 1.423345018469304, 3.082589664236521
        , -0.666695083234841, 1.125720158829322
        , 3.464220247867981, 1.038778673164228
        , 2.933373891043690, 1.628066445794965
        , 2.438787742568557, 1.882463318993010
        , 1.156530661730840, 2.009751864529454
        , 1.471382343384891, 2.299778193525170
        , 0.710205836898082, 2.982463815839716
        , 0.574655988307198, 1.897536569003219
        , -0.481180738023749, 3.243258184346350
        , 2.057634665518593, 1.158138790043992
        , 2.764066483839555, 1.749966682737925
        , 1.188742536016829, 1.830731044279766
        , 1.062643063683092, 2.464027653634071
        , 2.274832425103863, 2.976150757847411
        , -0.391925728913642, 2.459189741355192
        , 0.214293806085435, 1.173608789474463
        , 0.457619276357005, 2.215723842685201
        , -3.276595636919820, 2.027190449572088
        , 1.136813051985737, 2.411476105137650
        , 2.754286393796612, 1.525510310208986
        , 1.772550672306315, 1.678744622863505
        , 1.151056973603011, 0.934580048074344
        , 1.572929476032799, 1.934085057470423
        , 2.070875710998629, 2.577539460790769
        , -0.449497770648108, 1.397218305043158
        , 0.271960009043126, 2.461090941688480
        , 1.502451262202109, 2.258525296739507
        , 1.867786756275275, 1.401950731642538
        , -0.397193560292116, 1.959501415433482
        , -0.799417427242084, 1.323993855920221
        , 1.063400687978604, 1.690226733204147
        , 1.270024853701466, 2.875733585500125
        , -1.772562991551423, 0.986750552965217
        , 0.532253741838827, 0.619813291703453
        , 1.138823786767506, 2.845567676726764
        , 1.285833310664554, 1.354502560988730
        , 3.353413094595020, 2.320556360323725
        , -1.158616580303092, 1.591324520299605
        , 0.300520863016906, 0.893745059745460
        , 1.133378827555012, 1.534905699580405
        , 0.040235892761628, 2.368829212292840);


    cv::EM emModel(numClusters, cv::EM::COV_MAT_SPHERICAL, cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::EPS, maxiter, FLT_EPSILON));

    bool success = emModel.cv::EM::trainE(X, mean, covar, mixfrac, logLikelihoods,labels,probs);

    if (success)
        printf("Success\n");
    else
        printf("Failed\n");

    cv::Mat newmean = emModel.cv::EM::get<cv::Mat>("means");
    cv::Mat newweight = emModel.cv::EM::get<cv::Mat>("weights");
    std::vector<cv::Mat> newvar = emModel.cv::EM::get<std::vector<cv::Mat>>("covs");

    std::cout << "LLH\n" << logLikelihoods << "\n\n";
    std::cout <<"NEWMEAN\n"<< newmean<<"\n\n";
    std::cout << "NEWWEIGHTS\n" << newweight << "\n\n";
    std::cout << "NEWSIGMA1\n" << newvar[0] << "\n";
    std::cout << "NEWSIGMA2\n" << newvar[1] << "\n\n";

    scanf("%d", &i);


}

我为 maxiter 设置了值1。这样,只允许使用给定的初始参数完成第一个 Expectation 步骤,而不是 Maximization 步骤。

基本上我正在做的是使用给定的初始化参数检查 loglikelihoods。据我所知,在相同的初始参数下,无论协方差是通用(Generic)还是球形,第一个 E-step 都应该给出相同的结果(因为在初始参数相同时,E 步骤的方程式是一样的)。

然而,这两个结果有所不同,这是不寻常的。

球形GMM的输出

LLH
[-2.931318011752626;
  -5.437115860016975;
  -3.667420897566784;
  -3.274239241308711;
  -3.512773886210777;
  -2.906378240768077;
  -3.145533448424205;
  -3.642481763141506;
  -4.650852173801083;
  -2.975512838378632;
  -3.086635213907651;
  -3.118147350422762;
  -3.02853906103177;
  -2.976176233676393;
  -2.81071576186187;
  -5.359239062666107;
  -3.221342073874848;
  -4.364890315806004;
  -6.406040105709825;
  -4.978136266645567;
  -3.885209083574705;
  -4.577679428989389;
  -4.226199670930593;
  -3.149979577537311;
  -4.088024534030816;
  -4.880127907834185;
  -3.122802132137833;
  -2.646131497925848;
  -3.716520697504866;
  -3.224961706370468;
  -3.399575606628337;
  -4.557644951387831;
  -5.765981658468872;
  -8.37554003304065;
  -4.15407503628638;
  -7.21426852540733;
  -3.042169136121539;
  -6.295443523476872;
  -3.093842961677683;
  -4.495265457885608;
  -3.772140198103474;
  -3.935028312978259;
  -2.986138732258714;
  -4.33067591541676;
  -4.375573985871765;
  -4.519102272986751;
  -3.12565345021085;
  -6.194705958136484;
  -8.213811782800219;
  -3.538572963594731;
  -4.601005773126856;
  -4.678804407509512;
  -4.666930234789473;
  -4.083218428518389;
  -3.326859772310369;
  -6.643334152388976;
  -4.217625141712781;
  -4.034734535109596;
  -2.814646585841023;
  -5.508686183211898;
  -4.14839280421732;
  -3.381874480419066;
  -3.094404661798319;
  -3.230164061642246;
  -2.973679171864843;
  -3.195153092832868;
  -3.638373707120748;
  -4.353325915681707;
  -3.643177605071091;
  -3.243476559240886;
  -3.318103507338534;
  -3.083978027938285;
  -2.668655332006318;
  -4.331983635797021;
  -4.47589600534346;
  -3.569193241927018;
  -10.03869961506342;
  -3.064692854541199;
  -3.408620622130172;
  -3.24166993975683;
  -4.078595023654192;
  -3.122174103654394;
  -2.764010652665211;
  -4.990765765921384;
  -3.636211257765106;
  -2.980988844038545;
  -3.439623639120886;
  -4.547175745026653;
  -5.512926199458751;
  -3.467745941915222;
  -2.894697423739033;
  -7.428294826501615;
  -4.801229141025504;
  -2.958713722999831;
  -3.625296272969862;
  -3.108524035685613;
  -5.840522469845191;
  -4.677873214788853;
  -3.545697817254149;
  -3.884322355072641]

NEWMEAN
[2, 3;
  4, 5]

NEWWEIGHTS
[0.5, 0.5]

NEWSIGMA1
[2, 0;
  0, 2]
NEWSIGMA2
[3, 0;
  0, 3]

通用GMM的输出

Success
LLH
[-3.295935746117501;
  -5.790750393412513;
  -4.03429419104581;
  -3.71263347108744;
  -3.872144489126807;
  -3.274075193172064;
  -3.527637461113017;
  -3.998128057592525;
  -5.004203357427852;
  -3.356081006983058;
  -3.452688438712195;
  -3.483334605887541;
  -3.391826118348322;
  -3.33991303502579;
  -3.180885949172912;
  -5.710717495080189;
  -3.583938119252925;
  -4.719201739742172;
  -6.757228800928109;
  -5.330902409373262;
  -4.240038592584679;
  -4.935658714571719;
  -4.579970929182056;
  -3.513114639026275;
  -4.4522906425319;
  -5.234662136130952;
  -3.498568777684983;
  -3.050795422255073;
  -4.074129455698702;
  -3.60134679297809;
  -3.757286012239146;
  -4.910366860830981;
  -6.117838563989898;
  -8.727422737804261;
  -4.507867855539753;
  -7.566730910097011;
  -3.404909140030045;
  -6.645804917417026;
  -3.457241856183214;
  -4.851307959586312;
  -4.128055700550207;
  -4.289284804578601;
  -3.349349646254201;
  -4.683600556531062;
  -4.729290711442341;
  -4.879400262194479;
  -3.492881334973137;
  -6.545230968603729;
  -8.565356892970549;
  -3.895492494864605;
  -4.958752084391844;
  -5.032035254354657;
  -5.01908185049789;
  -4.453534794387545;
  -3.72062903465443;
  -6.99499565346089;
  -4.624250701398458;
  -4.391029813440906;
  -3.186754423099298;
  -5.859793331816577;
  -4.532911965106766;
  -3.759785100598588;
  -3.465822699838461;
  -3.589297731185657;
  -3.337252862476405;
  -3.55907539338088;
  -3.994110028844047;
  -4.713786842080629;
  -4.004537247619449;
  -3.619271902441214;
  -3.67643636032134;
  -3.4456342295298;
  -3.053353732793124;
  -4.687381509300202;
  -4.828411093949869;
  -3.925970540180974;
  -10.3918013626343;
  -3.426464016898089;
  -3.781854557415382;
  -3.603359371297283;
  -4.433394247345532;
  -3.483824967527195;
  -3.138004621822025;
  -5.342610735990292;
  -3.993620083098697;
  -3.344492754280415;
  -3.800540886464314;
  -4.900564570171441;
  -5.864239423435192;
  -3.824719561169336;
  -3.262235974218792;
  -7.778626718537899;
  -5.153466442034869;
  -3.324629479023928;
  -3.982041010852076;
  -3.50977175586426;
  -6.192064850807197;
  -5.030038737191805;
  -3.902365100307677;
  -4.240420295976542]

NEWMEAN
[2, 3;
  4, 5]

NEWWEIGHTS
[0.5, 0.5]

NEWSIGMA1
[2, 0;
  0, 2]
NEWSIGMA2
[3, 0;
  0, 3]

我写的 MATLAB 代码(球形 GMM)给出的结果与 OpenCV 中通用(Generic)协方差情况下第一个 E-step 的结果一致,这符合预期。但是这些结果与 OpenCV 中球形协方差情况下第一个 E-step 的结果不一致,我无法找到问题所在。

所以最后我不确定问题出在哪里。任何帮助,将不胜感激。谢谢你的时间。

0 个答案:

没有答案