Problem loading datasets into a kNN algorithm - Java

Time: 2016-03-09 05:21:12

Tags: java algorithm machine-learning knn

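I am applying the kNN algorithm to classify handwritten digits. Each digit is originally an 8 * 8 grid that is flattened into a 1 * 64 vector, and each row of data has a class code from 0 to 9.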
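As a concrete illustration of the flattening described above, here is a minimal sketch that turns an 8 * 8 grid into a 64-element vector. It assumes row-major order (row by row); the question does not say which order the dataset actually uses, and FlattenDemo and flatten are hypothetical names.

```java
public class FlattenDemo {
    // Row-major flattening (assumed order): cell (r, c) lands at index r * cols + c.
    public static int[] flatten(int[][] grid) {
        int rows = grid.length, cols = grid[0].length;
        int[] vec = new int[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                vec[r * cols + c] = grid[r][c];
            }
        }
        return vec;
    }

    public static void main(String[] args) {
        int[][] grid = new int[8][8];
        grid[0][1] = 5;  // row 0, column 1
        grid[7][7] = 16; // last cell
        int[] vec = flatten(grid);
        System.out.println(vec.length);             // 64
        System.out.println(vec[1] + " " + vec[63]); // 5 16
    }
}
```

In the CSV files read by load() below, each line then holds these 64 values followed by the class code as the last comma-separated field.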

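As far as I can tell, my code should work in theory, but this is my first attempt at using this algorithm. My problem arises when I try to feed my datasets through the algorithm: an error is thrown on the lines I have highlighted in the code. The training dataset can be found here and the validation set here. In case it helps, I have also left in my previous, working main function.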

ImageMatrix.java

import java.util.*;

public class ImageMatrix {
    private int[] data;
    private int classCode;

    public ImageMatrix(int[] data, int classCode) {
        assert data.length == 64; // expects exactly 64 values (8x8 flattened); only checked when run with -ea
        this.data = data;
        this.classCode = classCode;
    }

    public String toString() {
        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; // human-readable form
    }

    public int[] getData() {
        return data;
    }

    public int getClassCode() {
        return classCode;
    }

}

ImageMatrixDB.java

import java.util.*;
import java.io.*;

public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line = null;

            while((line = br.readLine()) != null) {
                int lastComma = line.lastIndexOf(',');
                int classCode = Integer.parseInt(line.substring(1 + lastComma));
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                ImageMatrix matrix = new ImageMatrix(data, classCode);
                list.add(matrix);
            }
        }
        return this;
    }

    public void printResults(){ //output results 
        for(ImageMatrix matrix: list){
            System.out.println(matrix);
        }
    }


    public Iterator<ImageMatrix> iterator() {
        return this.list.iterator();
    }

    /// kNN implementation ///
    public static int distance(int[] a, int[] b) {
        int sum = 0;
        for(int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int)Math.sqrt(sum); // Euclidean distance: square root of the sum of squared differences
    }


    public static int classify(List<ImageMatrix> trainingSet, int[] curData) {
        int label = 0, bestDistance = Integer.MAX_VALUE;
        for(ImageMatrix matrix: trainingSet) {
            int dist = distance(matrix.getData(), curData);
            if(dist < bestDistance) {
                bestDistance = dist;
                label = matrix.getClassCode(); // remember the class of the closest sample; don't overwrite the query
            }
        }
        return label;
    }


    public static void main(String[] argv) throws IOException {
        ImageMatrixDB i = new ImageMatrixDB();
        List<ImageMatrix> trainingSet = i.load("cw2DataSet1.csv"); // << ERROR HERE
        List<ImageMatrix> validationSet = i.load("cw2DataSet2.csv"); //<< ERROR HERE
        int numCorrect = 0;
        for(ImageMatrix matrix:validationSet) {
            if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
        }
        System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");
    }
    //////////////////////////////////////////

    // Previous working dataset Load //
 /*   public static void main(String[] args){
        ImageMatrixDB i = new ImageMatrixDB();
        try{
            i.load("cw2DataSet1.csv"); 
            i.printResults();
        }
        catch(Exception ex){
            ex.printStackTrace();
        }
    } */

}
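The classify method in the question is meant to be 1-nearest-neighbour: it keeps only the single closest training vector. A common generalisation for k > 1 is a majority vote among the k closest vectors. The sketch below assumes class codes 0..9 and breaks vote ties toward the smallest class code, which is an arbitrary choice; KnnVoteSketch and its parameters are hypothetical and use plain int[] vectors and Integer labels so the example stands alone.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class KnnVoteSketch {
    // Squared Euclidean distance; the sqrt is not needed just to rank neighbours.
    public static long squaredDistance(int[] a, int[] b) {
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long d = (long) a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Majority vote among the k training vectors closest to the query.
    // Assumes class codes 0..9; vote ties go to the smallest class code.
    public static int classify(List<int[]> trainData, List<Integer> trainLabels, int[] query, int k) {
        Integer[] order = new Integer[trainData.size()];
        for (int i = 0; i < order.length; i++) order[i] = i;
        // Sort training indices by distance to the query, nearest first.
        Arrays.sort(order, Comparator.comparingLong(i -> squaredDistance(trainData.get(i), query)));

        int[] votes = new int[10];
        for (int j = 0; j < Math.min(k, order.length); j++) {
            votes[trainLabels.get(order[j])]++;
        }
        int best = 0;
        for (int c = 1; c < votes.length; c++) {
            if (votes[c] > votes[best]) best = c;
        }
        return best;
    }

    public static void main(String[] args) {
        List<int[]> data = Arrays.asList(new int[]{0, 0}, new int[]{1, 0},
                new int[]{10, 10}, new int[]{11, 10}, new int[]{10, 11});
        List<Integer> labels = Arrays.asList(3, 3, 7, 7, 7);
        System.out.println(classify(data, labels, new int[]{9, 9}, 3)); // 7
    }
}
```

With the question's classes, trainData would come from matrix.getData() and trainLabels from matrix.getClassCode() for each training ImageMatrix. Sorting all indices per query is simple rather than fast; with k = 1 this reduces to the nearest-neighbour rule.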

EDIT ///

The current error message is:

Exception in thread "main" java.lang.Error: Unresolved compilation problems: 
    Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
    Type mismatch: cannot convert from ImageMatrixDB to List<ImageMatrix>
    at ImageMatrixDB.main(ImageMatrixDB.java:64)

I have also run into other errors while testing.

1 Answer:

Answer 0 (score: 1)

Given the way you have designed your class, it should be used like this:

ImageMatrixDB trainingSet = new ImageMatrixDB();
ImageMatrixDB validationSet = new ImageMatrixDB();
trainingSet.load("cw2DataSet1.csv");
validationSet.load("cw2DataSet2.csv");

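Note that there are two instances of ImageMatrixDB instead of one, which ensures that the training and validation data are loaded into separate lists. With a single instance, the second load() call would append the validation rows to the same list that already holds the training rows.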
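With two ImageMatrixDB objects, the rest of the original main still won't compile as written: classify expects a List<ImageMatrix>, and ImageMatrixDB has no size() method. The sketch below shows one way to fit the pieces together, assuming classify is changed to accept an Iterable and the database class gains a size() method. Sample and SampleDB are hypothetical stand-ins for ImageMatrix and ImageMatrixDB (add() plays the role of load()) so the example runs without the CSV files.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class TwoInstanceSketch {
    // Hypothetical stand-in for the question's ImageMatrix.
    public static class Sample {
        final int[] data;
        final int classCode;
        public Sample(int[] data, int classCode) { this.data = data; this.classCode = classCode; }
    }

    // Hypothetical stand-in for ImageMatrixDB: every instance owns a separate list.
    public static class SampleDB implements Iterable<Sample> {
        private final List<Sample> list = new ArrayList<>();
        public SampleDB add(Sample s) { list.add(s); return this; } // plays the role of load()
        public int size() { return list.size(); }                  // assumed addition
        @Override public Iterator<Sample> iterator() { return list.iterator(); }
    }

    // 1-nearest-neighbour over any Iterable, so a SampleDB can be passed in directly.
    public static int classify(Iterable<Sample> trainingSet, int[] query) {
        int label = -1;
        long bestDistance = Long.MAX_VALUE;
        for (Sample s : trainingSet) {
            long sum = 0;
            for (int i = 0; i < query.length; i++) {
                long d = (long) s.data[i] - query[i];
                sum += d * d; // squared Euclidean distance
            }
            if (sum < bestDistance) {
                bestDistance = sum;
                label = s.classCode; // store the label, leave the query untouched
            }
        }
        return label;
    }

    // Percentage of validation samples whose predicted class matches their class code.
    public static double accuracy(SampleDB trainingSet, SampleDB validationSet) {
        int numCorrect = 0;
        for (Sample s : validationSet) {
            if (classify(trainingSet, s.data) == s.classCode) numCorrect++;
        }
        return 100.0 * numCorrect / validationSet.size();
    }

    public static void main(String[] args) {
        SampleDB trainingSet = new SampleDB();   // two separate instances,
        SampleDB validationSet = new SampleDB(); // as the answer suggests
        trainingSet.add(new Sample(new int[]{0, 0}, 1)).add(new Sample(new int[]{9, 9}, 2));
        validationSet.add(new Sample(new int[]{1, 0}, 1)).add(new Sample(new int[]{8, 9}, 1));
        System.out.println("Accuracy: " + accuracy(trainingSet, validationSet) + "%"); // Accuracy: 50.0%
    }
}
```

In the question's code, the equivalent changes would be declaring classify(Iterable<ImageMatrix> trainingSet, int[] curData) and adding a public int size() { return list.size(); } method to ImageMatrixDB; neither exists in the original.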

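A quick note: when computing distances for kNN, you should be able to use the squared distance instead. Because the square root is monotonic, the closest sample under squared distance is the same one, and you avoid calling sqrt, which is an expensive operation, for every comparison. So the square root in return (int)Math.sqrt(sum); isn't needed.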
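To illustrate the point above, here is a small sketch comparing a squared distance with the question's sqrt-then-truncate distance. Besides the cost of sqrt, it shows a side effect of the (int) cast: two different distances can truncate to the same integer and tie, while the squared distances still order them correctly. SquaredDistanceDemo and its method names are hypothetical; truncatedDistance reproduces the question's distance for comparison.

```java
public class SquaredDistanceDemo {
    // Squared Euclidean distance: no Math.sqrt and no truncation to int.
    public static long squaredDistance(int[] a, int[] b) {
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long d = (long) a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Same computation as the question's distance(): sqrt, then cast to int.
    public static int truncatedDistance(int[] a, int[] b) {
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int) Math.sqrt(sum);
    }

    public static void main(String[] args) {
        int[] q = {0, 0};
        int[] p1 = {2, 0}; // true distance 2.0
        int[] p2 = {2, 1}; // true distance about 2.236
        System.out.println(truncatedDistance(q, p1) + " " + truncatedDistance(q, p2)); // 2 2 (a tie)
        System.out.println(squaredDistance(q, p1) + " " + squaredDistance(q, p2));     // 4 5
    }
}
```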