Java + OpenCV SVM 训练错误

时间:2015-07-16 12:49:48

标签: java opencv

我正在尝试用一组图像训练 SVM 并生成训练文件(模型),然后用它识别图像中的某些物体。为此,我分别准备了用于训练的正样本集和负样本集。

问题出在用训练好的文件进行测试时:程序报错,指出输入样本的大小与训练时使用的样本大小不同。但这说不通,因为测试用的正是一张已经参与过训练的图像。

/**
 * Code from the question, kept as posted: trains a CvSVM on the images in
 * PATH_POSITIVE / PATH_NEGATIVE, saves it to XML, reloads it and predicts
 * FILE_TEST. The predict() call fails with "The sample size is different from
 * what has been used for training"; the NOTE(review) comments mark why.
 */
public class Training{

protected static final String PATH_POSITIVE = "data/positivo/";
protected static final String PATH_NEGATIVE = "data/negativo/";
protected static final String XML = "data/test.xml";
protected static final String FILE_TEST = "data/positivo/1.jpg";

// Load the OpenCV native library before any Mat is created.
static {
    System.loadLibrary( Core.NATIVE_LIBRARY_NAME );
}

/**
 * Converts a BGR image to grayscale and returns it as a CV_32FC1 matrix.
 *
 * NOTE(review): reshape(1, width * height) produces a (width*height) x 1
 * column, i.e. one row per pixel. Every pixel therefore becomes its own
 * one-feature sample, instead of the whole image being a single sample row
 * (which would be reshape(1, 1)). This is the root cause of the size-mismatch
 * error at prediction time.
 */
protected static Mat getMat( Mat img ) {
    Mat timg = new Mat();
    Imgproc.cvtColor( img, timg, Imgproc.COLOR_BGR2GRAY );
    timg = timg.reshape( 1, timg.width() * timg.height() );
    timg.convertTo( timg, CvType.CV_32FC1 );
    return timg;
}

/** Trains on both folders, saves the model, reloads it and predicts FILE_TEST. */
public static void main( String[ ] args ) {

    Mat classes = new Mat();
    Mat trainingData = new Mat();

    Mat trainingImages = new Mat();
    Mat trainingLabels = new Mat();

    CvSVM clasificador;

    // Positive images: every row contributed by an image gets label 1.
    for ( File file : new File( PATH_POSITIVE ).listFiles() ) {
        Mat img = Highgui.imread( file.getAbsolutePath() );
        trainingImages.push_back( getMat( img ) );
        // Size(width = 1, height = w*h): one label per pixel row, matching getMat's per-pixel samples.
        trainingLabels.push_back( Mat.ones( new Size( 1, img.width() * img.height() ), CvType.CV_32FC1 ) );
    }

    // Negative images: every row contributed by an image gets label 0.
    for ( File file : new File( PATH_NEGATIVE ).listFiles() ) {
        Mat img = Highgui.imread( file.getAbsolutePath() );
        trainingImages.push_back( getMat( img ) );
        trainingLabels.push_back( Mat.zeros( new Size( 1, img.width() * img.height() ), CvType.CV_32FC1 ) );
    }

    trainingImages.copyTo( trainingData );
    // Already CV_32FC1 (getMat converted it), so this conversion changes nothing.
    trainingData.convertTo( trainingData, CvType.CV_32FC1 );
    trainingLabels.copyTo( classes );

    CvSVMParams params = new CvSVMParams();
    params.set_kernel_type( CvSVM.LINEAR );
    params.set_svm_type( CvSVM.C_SVC );
    // NOTE(review): gamma is used by the POLY/RBF/SIGMOID kernels, not by LINEAR.
    params.set_gamma( 3 );

    // This constructor already trains the model with the params above.
    clasificador = new CvSVM( trainingData, classes, new Mat(), new Mat(), params );
    // NOTE(review): retrains on the same data; this two-argument overload presumably
    // uses default CvSVMParams rather than the LINEAR settings above — confirm.
    clasificador.train( trainingData, classes );
    clasificador.save( XML );

    // Training finished; reload the saved model and test it on a single file.

    clasificador = new CvSVM();
    clasificador.load( new File( XML ).getAbsolutePath() );
    // getMat already returns a (w*h) x 1 CV_32FC1 matrix, so the next two lines
    // leave timg unchanged: predict() receives w*h values, while the model was
    // trained on one-value (single pixel) samples.
    Mat timg = getMat( Highgui.imread( new File( FILE_TEST ).getAbsolutePath() ) );
    timg = timg.reshape( 1, timg.width() * timg.height() );
    timg.convertTo( timg, CvType.CV_32FC1 );

    //Here the error occurs
    //Exception in thread "main" CvException [org.opencv.core.CvException: cv::Exception: ..\..\..\..\opencv\modules\ml\src\inner_functions.cpp:1114: error: (-209) The sample size is different from what has been used for training in function cvPreparePredictData
    System.out.println( clasificador.predict( timg ) );

}

}

我正在使用Java 8和OpenCV 2.4.10

1 个答案:

答案 0 :(得分:2)

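根据 @berak 的评论,我对原始代码做了一些修改,下面是修改后的代码,在我这里可以正常工作。关键改动是:每张图像用 `reshape(1, 1)` 展平为一行,作为一个训练样本,并且每张图像只对应一个标签。原代码用 `reshape(1, width * height)` 把每个像素都变成了一个单独的样本(特征维度为 1),因此预测时传入整张图像的全部像素,就会触发"样本大小与训练时不同"的错误: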

/**
 * Trains a linear CvSVM on grayscale images from a positive and a negative
 * folder, saves the model to {@link #XML}, reloads it and prints the
 * prediction for {@link #FILE_TEST}.
 *
 * Each image is flattened into ONE sample row (its pixels in row order), so
 * every training image and the test image must have the same width and height.
 */
public class Training{

protected static final String PATH_POSITIVE = "data/positivo/";
protected static final String PATH_NEGATIVE = "data/negativo/";
protected static final String XML = "data/test.xml";
protected static final String FILE_TEST = "data/negativo/1.jpg";

// One row per training image (CV_32FC1, pixel values scaled to [0, 1]).
private static Mat trainingImages;
// One 1x1 label per training image: 1 = positive, 0 = negative.
private static Mat trainingLabels;
private static Mat trainingData;
private static Mat classes;
private static CvSVM clasificador;

static {
    System.loadLibrary( Core.NATIVE_LIBRARY_NAME );
    trainingImages = new Mat();
    trainingLabels = new Mat();
    trainingData = new Mat();
    classes = new Mat();
}

/** Loads both sample folders, trains and saves the model, then runs the test prediction. */
public static void main( String[ ] args ) {
    trainPositive();
    trainNegative();
    train();
    test();
}

/**
 * Loads the saved model from {@link #XML} and prints the model, the test
 * sample and its predicted label.
 *
 * @throws IllegalArgumentException if {@link #FILE_TEST} cannot be read as an image
 * @throws IllegalStateException if the test image's pixel count differs from
 *         the sample size the model was trained on
 */
protected static void test() {
    if ( clasificador == null ) {
        // train() has not run in this JVM; start from an empty model and load it from disk.
        clasificador = new CvSVM();
    }
    clasificador.load( new File( XML ).getAbsolutePath() );
    System.out.println( clasificador );
    // Use exactly the same preprocessing as the training samples (grayscale,
    // CV_32FC1, scaled by 1/255). Previously the test image was not scaled, so
    // the linear SVM saw values 255 times larger than anything it was trained on.
    Mat out = getMat( new File( FILE_TEST ).getAbsolutePath() ).reshape( 1, 1 );
    if ( out.cols() != clasificador.get_var_count() ) {
        throw new IllegalStateException( "Test image " + FILE_TEST + " has " + out.cols()
                + " pixels but the model was trained on " + clasificador.get_var_count()
                + "; all images must have the same size" );
    }
    System.out.println( out );
    System.out.println( clasificador.predict( out ) );
}

/**
 * Trains a linear SVM (other CvSVMParams left at their defaults) on the
 * collected samples and saves it to {@link #XML}.
 *
 * @throws IllegalStateException if no training samples were loaded
 */
protected static void train() {
    if ( trainingImages.empty() ) {
        throw new IllegalStateException( "No training images found in "
                + PATH_POSITIVE + " or " + PATH_NEGATIVE );
    }
    trainingImages.copyTo( trainingData );
    trainingData.convertTo( trainingData, CvType.CV_32FC1 );
    trainingLabels.copyTo( classes );
    CvSVMParams params = new CvSVMParams();
    params.set_kernel_type( CvSVM.LINEAR );
    // This constructor trains the model immediately; no separate train() call is needed.
    clasificador = new CvSVM( trainingData, classes, new Mat(), new Mat(), params );
    clasificador.save( XML );
}

/** Adds every image in {@link #PATH_POSITIVE} as a sample labelled 1. */
protected static void trainPositive() {
    addSamples( PATH_POSITIVE, true );
}

/** Adds every image in {@link #PATH_NEGATIVE} as a sample labelled 0. */
protected static void trainNegative() {
    addSamples( PATH_NEGATIVE, false );
}

/**
 * Appends each file in {@code dir} as one flattened sample row, with a 1x1
 * label of 1 ({@code positive}) or 0 (negative).
 *
 * @throws IllegalStateException if {@code dir} is not a readable directory
 * @throws IllegalArgumentException if a file is not a readable image, or its
 *         pixel count differs from the images already loaded
 */
private static void addSamples( String dir, boolean positive ) {
    File[] files = new File( dir ).listFiles();
    if ( files == null ) {
        // listFiles() returns null (not an empty array) for a missing or unreadable directory.
        throw new IllegalStateException( "Not a readable directory: " + dir );
    }
    Size labelSize = new Size( 1, 1 );
    for ( File file : files ) {
        Mat row = getMat( file.getAbsolutePath() ).reshape( 1, 1 );
        if ( !trainingImages.empty() && row.cols() != trainingImages.cols() ) {
            throw new IllegalArgumentException( "Image " + file + " has " + row.cols()
                    + " pixels, expected " + trainingImages.cols()
                    + "; all images must have the same size" );
        }
        trainingImages.push_back( row );
        trainingLabels.push_back( positive
                ? Mat.ones( labelSize, CvType.CV_32FC1 )
                : Mat.zeros( labelSize, CvType.CV_32FC1 ) );
    }
}

/**
 * Reads {@code path} as an 8-bit grayscale image and converts it to CV_32FC1
 * with pixel values scaled to [0, 1].
 *
 * @throws IllegalArgumentException if the file cannot be decoded as an image
 */
protected static Mat getMat( String path ) {
    Mat con = Highgui.imread( path, Highgui.CV_LOAD_IMAGE_GRAYSCALE );
    if ( con.empty() ) {
        // imread reports failure with an empty Mat rather than an exception.
        throw new IllegalArgumentException( "Could not read image: " + path );
    }
    Mat img = new Mat();
    con.convertTo( img, CvType.CV_32FC1, 1.0 / 255.0 );
    return img;
}

}
相关问题