无法使SVM按预期工作

时间:2017-01-27 09:57:05

标签: c# opencv svm emgucv

我是这个主题的新手。我尝试做的是为不同参数的相同数据集训练许多SVM(我希望将来可以使用不同类型的图像签名)然后预测每个SVM并接受最常找到的类。

我试图阅读许多人关于SVM图像训练的代码,但是在我的代码中无法弄清楚我做错了什么。我试过什么,svm.Predict总是返回0。

非常感谢任何帮助或提示。

    internal class SVMClassifier
    {
        Dictionary<int, string> classIndex_name;
        List<SVM> svms;

        internal void Train(string trainFolder)
        {
            this.classIndex_name = new Dictionary<int, string>();
            Dictionary<int, List<Mat>> class_mats = getMats(trainFolder, this.classIndex_name);
            this.svms = new List<SVM>();


            Mat samples; Mat responses;
            getTrainingData(class_mats, out samples, out responses);
            svms.Add(trainSVM(samples, responses));
            svms.Add(trainSVM(samples, responses, SVM.SvmType.CSvc, SVM.SvmKernelType.Linear, 0d, 0d, 10d, TermCritType.Iter | TermCritType.Eps, 1000, 0.000001d, 0d, 0d));
            svms.Add(trainSVM(samples, responses, SVM.SvmType.CSvc, SVM.SvmKernelType.Rbf, 100d, 100d, 1d, TermCritType.Iter | TermCritType.Eps, 1000, 0.000001d, 0.1d, 0.5d));

            samples.Dispose(); responses.Dispose();

            foreach (Mat mat in class_mats.Values.SelectMany((a) => a))
                mat.Dispose();
        }
        private static Dictionary<int, List<Mat>> getMats(string trainFolder, Dictionary<int, string> classIndex_name)
        {
            Dictionary<int, List<Mat>> class_mats = new Dictionary<int, List<Mat>>();
            DirectoryInfo diTrain = new DirectoryInfo(trainFolder);
            int i = 0;
            foreach (var di in diTrain.GetDirectories())//classes are according to the directories
            {
                var dirName = di.Name;
                classIndex_name[i] = dirName;
                var fileNames = di.GetFiles().Select((a) => a.FullName).ToList();
                fileNames.Sort(new Dece.Misc.NumericSuffixFileFullNameComparer());
                class_mats[i] = fileNames.Select((a) => getMat(a, true)).ToList();
                i++;
            }
            return class_mats;
        }

        private static SVM trainSVM(Mat samples, Mat responses,
            SVM.SvmType? svm_Type = null, SVM.SvmKernelType? svm_KernelType = null, double? gamma = null, double? degree = null, double? c = null,
            TermCritType? criteriaType = null, int? criteriaMaxCount = null, double? criteriaEps = null, double? p = null, double? nu=null)
        {
            SVM svm = new SVM();
            if (svm_Type != null) svm.Type = (SVM.SvmType)svm_Type;
            if (svm_KernelType != null) svm.SetKernel((SVM.SvmKernelType)svm_KernelType);
            if (gamma != null) svm.Gamma = (double)gamma;
            if (degree != null) svm.Degree = (double)degree;
            if (c != null) svm.C = (double)c;

            if ((criteriaType != null) || (criteriaMaxCount != null) || (criteriaEps != null))
            {
                var t = new MCvTermCriteria((int)criteriaMaxCount, (double)criteriaEps);
                if (criteriaType != null) t.Type = (TermCritType)criteriaType;
                svm.TermCriteria = t;
            }


            if (p != null) svm.P = (double)p;
            if (nu != null) svm.Nu = (double)nu;

            if (!svm.Train(samples, DataLayoutType.RowSample, responses))
                throw new Exception();
            return svm;
        }

        private static void getTrainingData(Dictionary<int, List<Mat>> class_mats, out Mat samples, out Mat responses)
        {
            samples = null;
            List<int> lstResp = new List<int>();
            foreach (int cls in class_mats.Keys)
            {
                int count = 0;
                foreach (Mat mat in class_mats[cls])
                    using (var desc = mat.Reshape(0, 1))
                    {
                        if (samples == null)
                            samples = new Mat(desc.Cols, 0, desc.Depth, 1);
                        samples.PushBack(desc);
                        count += desc.Rows;
                    }
                for (int i = 0; i < count; i++)
                    lstResp.Add(cls);
            }

            //responses = new Mat(new Size(lstResp.Count, 1), DepthType.Cv32S, 1);
            //for (int i = 0; i < lstResp.Count; i++)
            //    responses.SetValue(0, i, lstResp[i]);

            responses = new Mat(new Size(1, lstResp.Count), DepthType.Cv32S, 1);
            for (int i = 0; i < lstResp.Count; i++)
                responses.SetValue(i, 0, lstResp[i]);

            if (samples.Depth != DepthType.Cv32F)
                samples.ConvertTo(samples, DepthType.Cv32F);

            CvInvoke.Normalize(samples, samples, -1, 1, NormType.MinMax);
        }

        internal void Detect(IEnumerable<string> fileNames, Action<ShapeInfo> detected)
        {
            foreach (var fn in fileNames)
                using (Mat mat = getMat(fn, false))
                {
                    {
                        using (var samples = mat.Reshape(0, 1))
                        {
                            if (samples.Depth != DepthType.Cv32F)
                                samples.ConvertTo(samples, DepthType.Cv32F);
                            CvInvoke.Normalize(samples, samples, -1, 1, NormType.MinMax);
                            foreach (var svm in this.svms)
                            {
                                Mat res = new Mat();
                                float p0 = svm.Predict(samples, res, 0);
                                float p1 = svm.Predict(samples, res, 1);
                                float p2 = svm.Predict(samples, res, 2);
                                float p3 = svm.Predict(samples, res, 3);
                                float p4 = svm.Predict(samples, res, 4);
                                float p = svm.Predict(samples, res);

                                foreach (var val in toIEnumerable(p0, p1, p2, p3, p4, p))
                                    if (val != 0f)
                                    {
                                        System.Windows.Forms.MessageBox.Show("never enters here :(");
                                    }
                            }
                        }
                    }
                }
        }

        private static Mat getMat(string fn, bool train)
        {
            var mat = new Mat(fn, ImreadModes.Grayscale);
            mat.Resize(new Size(128, 128));
            return mat;
        }
        private static IEnumerable<T> toIEnumerable<T>(params T[] items)
        {
            if (items != null)
                foreach (var item in items)
                    yield return item;
        }

    }

Mat.SetValue扩展名来自here

我希望这样的问题对于这个网站的格式是有用的。如果不是这个问题可以封闭擦除,没问题。我试图了解我们应该如何训练带有图像的svm。

1 个答案:

答案 0 :(得分:0)

是的,我是傻瓜。问题是发送一个空的Mat to Predict函数。当我尝试使用null时,我开始得到类预测。 (我仍然不明白为什么。)

分类问题让我觉得自己像个笨蛋。关于我如何使用svm的任何评论,应该做什么都是好的。

        internal void Detect(IEnumerable<string> fileNames, Action<ShapeInfo> detected)
        {
            foreach (var fn in fileNames)
                using (Mat mat = getMat(fn, false))
                {
                    {
                        using (var samples = mat.Reshape(0, 1))
                        {
                            if (samples.Depth != DepthType.Cv32F)
                                samples.ConvertTo(samples, DepthType.Cv32F);
                            CvInvoke.Normalize(samples, samples, -1, 1, NormType.MinMax);
                            foreach (var svm in this.svms)
                            {
                                float p = svm.Predict(samples, null);

                            }
                        }
                    }
                }
        }