Changing the value of k in the kNN algorithm - Java

Date: 2016-03-09 10:42:08

Tags: java algorithm csv machine-learning knn

I have applied the kNN algorithm to classify handwritten digits. Each digit starts out as an 8 × 8 matrix, which is flattened into a 1 × 64 vector.
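For reference, flattening an 8 × 8 grid into a 64-element vector can be sketched as follows. The class and method names (`FlattenDemo`, `flatten`) are hypothetical, and row-major ordering is an assumption. The CSV files in the question already store each digit as 64 comma-separated values followed by a class code, so this step only matters when starting from a 2-D grid.

```java
// Hypothetical helper: flattens a 2-D image into a 1-D vector.
// Row-major order is an assumption; any fixed order works for kNN as long
// as training and validation images use the same one.
class FlattenDemo {
    static int[] flatten(int[][] image) {
        int rows = image.length;
        int cols = image[0].length;
        int[] vector = new int[rows * cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // Row r occupies indices r*cols .. r*cols + cols - 1
                vector[r * cols + c] = image[r][c];
            }
        }
        return vector;
    }

    public static void main(String[] args) {
        int[][] image = new int[8][8];
        image[0][1] = 5;  // row 0, column 1 -> index 1
        image[2][3] = 16; // row 2, column 3 -> index 2*8 + 3 = 19
        int[] v = flatten(image);
        System.out.println(v.length + " " + v[1] + " " + v[19]); // prints "64 5 16"
    }
}
```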

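My code currently applies the kNN algorithm, but only with k = 1. I'm not sure how to change the value of k: I've tried a few things, but I keep getting errors. If anyone could point me in the right direction, I would really appreciate it. The training dataset can be found here and the validation set here.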

ImageMatrix.java

import java.util.*;

public class ImageMatrix {
    private int[] data;
    private int classCode;
    private int curData;
    public ImageMatrix(int[] data, int classCode) {
        assert data.length == 64; // exactly 64 values (the flattened 8 x 8 image)
        this.data = data;
        this.classCode = classCode;
    }

    public String toString() {
        return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable
    }

    public int[] getData() {
        return data;
    }

    public int getClassCode() {
        return classCode;
    }
    public int getCurData() {
        return curData;
    }



}

ImageMatrixDB.java

import java.util.*;
import java.io.*;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line = null;

            while((line = br.readLine()) != null) {
                int lastComma = line.lastIndexOf(',');
                int classCode = Integer.parseInt(line.substring(1 + lastComma));
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..
                list.add(matrix);
            }
        }
        return this;
    }

    public void printResults(){ //output results 
        for(ImageMatrix matrix: list){
            System.out.println(matrix);
        }
    }


    public Iterator<ImageMatrix> iterator() {
        return this.list.iterator();
    }

    /// kNN implementation ///
    public static int distance(int[] a, int[] b) {
        int sum = 0;
        for(int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (int)Math.sqrt(sum);
    }


    public static int classify(ImageMatrixDB trainingSet, int[] curData) {
        int label = 0, bestDistance = Integer.MAX_VALUE;
        for(ImageMatrix matrix: trainingSet) {
            int dist = distance(matrix.getData(), curData);
            if(dist < bestDistance) {
                bestDistance = dist;
                label = matrix.getClassCode();
            }
        }
        return label;
    }


    public int size() {
        return list.size(); // returns size of the list
    }


    public static void main(String[] argv) throws IOException {
        ImageMatrixDB trainingSet = new ImageMatrixDB();
        ImageMatrixDB validationSet = new ImageMatrixDB();
        trainingSet.load("cw2DataSet1.csv");
        validationSet.load("cw2DataSet2.csv"); 
        int numCorrect = 0;
        for(ImageMatrix matrix:validationSet) {
            if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
        } //285 correct
        System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%"); 
        System.out.println();
    }
}

1 Answer

Answer 0 (score: 2)

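In the for loop of classify, you are trying to find the single training example closest to the test point. You need to replace that with code that finds the K training points closest to the test data. Then call getClassCode on each of those K points and find the majority (i.e., most frequent) class code among them. classify then returns that majority class code.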
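The "find the K nearest training points" step can also be done without sorting the whole training set. The sketch below swaps in a different technique from a full sort: a bounded max-heap (`java.util.PriorityQueue`) that keeps only the K best candidates seen so far. Plain `int` arrays stand in for `ImageMatrixDB`, and the names `KNearest`, `squaredDistance` and `nearestLabels` are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical sketch: labels of the k training vectors nearest to a query.
class KNearest {
    // Squared Euclidean distance: same ordering as Euclidean, no sqrt needed.
    static int squaredDistance(int[] a, int[] b) {
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            int d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Returns the labels of the k nearest training vectors, nearest first.
    static List<Integer> nearestLabels(int[][] train, int[] labels, int[] query, int k) {
        // Entries are {distance, label}; the largest distance sits on top.
        PriorityQueue<int[]> heap = new PriorityQueue<>((x, y) -> Integer.compare(y[0], x[0]));
        for (int i = 0; i < train.length; i++) {
            heap.offer(new int[] {squaredDistance(train[i], query), labels[i]});
            if (heap.size() > k) {
                heap.poll(); // evict the farthest of the current candidates
            }
        }
        List<Integer> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(heap.poll()[1]); // polls farthest first
        }
        Collections.reverse(result); // now nearest first
        return result;
    }

    public static void main(String[] args) {
        int[][] train = {{0, 0}, {1, 0}, {5, 5}, {6, 5}, {0, 1}};
        int[] labels = {0, 0, 1, 1, 0};
        System.out.println(nearestLabels(train, labels, new int[] {5, 4}, 3)); // prints [1, 1, 0]
    }
}
```

The heap costs O(n log k) instead of O(n log n) for a full sort; with a few thousand 64-pixel digits either approach is fast enough, so this is mainly a design alternative.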

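You can break ties (i.e., when the two most frequent class codes occur equally often among the K nearest training points) in whatever way suits your needs.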
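One possible tie-breaking rule, sketched here as an assumption rather than the only option: among the class codes that share the highest count, pick the one whose member lies nearest to the query. The input must be the K labels ordered nearest first; the class and method names (`MajorityVote`, `vote`) are hypothetical.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: majority vote with a nearest-neighbour tie-break.
class MajorityVote {
    static int vote(List<Integer> labelsNearestFirst) {
        Map<Integer, Integer> counts = new HashMap<>();
        int maxCount = 0;
        for (int label : labelsNearestFirst) {
            int c = counts.merge(label, 1, Integer::sum); // increment this label's count
            maxCount = Math.max(maxCount, c);
        }
        // Scan from the nearest neighbour outward; the first label holding
        // the maximum count wins, which resolves ties by distance.
        for (int label : labelsNearestFirst) {
            if (counts.get(label) == maxCount) {
                return label;
            }
        }
        throw new IllegalArgumentException("no labels to vote on");
    }

    public static void main(String[] args) {
        System.out.println(vote(Arrays.asList(7, 1, 1, 7))); // 2-2 tie, 7 is nearest: prints 7
        System.out.println(vote(Arrays.asList(3, 5, 5)));    // clear majority: prints 5
    }
}
```

Other common choices are to use an odd K (which avoids ties only for two classes, not for ten digits) or to shrink K by one and vote again until the tie disappears.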

I don't have much Java experience, but from looking at the language reference I came up with the implementation below.

// Intended to live in ImageMatrixDB, whose java.util.* import covers Arrays, Map and HashMap
public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {
    // Each row holds {distance to curData, class code} for one training example
    int[][] distances = new int[trainingSet.size()][2];
    int i = 0;

    // Place distances in an array to be sorted
    for (ImageMatrix matrix : trainingSet) {
        distances[i][0] = distance(matrix.getData(), curData);
        distances[i][1] = matrix.getClassCode();
        i++;
    }

    // Sort by distance, nearest first
    Arrays.sort(distances, (lhs, rhs) -> Integer.compare(lhs[0], rhs[0]));

    // Count how often each class code occurs among the k nearest neighbours
    // (k is capped at the training-set size to avoid running off the array)
    int limit = Math.min(k, distances.length);
    Map<Integer, Integer> majorityMap = new HashMap<Integer, Integer>();
    for (i = 0; i < limit; i++) {
        majorityMap.merge(distances[i][1], 1, Integer::sum);
    }

    // Find the class code with the highest frequency
    int label = 0;
    int maxVal = -1;
    for (Map.Entry<Integer, Integer> entry : majorityMap.entrySet()) {
        int entryVal = entry.getValue();
        if (entryVal > maxVal) {
            maxVal = entryVal;
            label = entry.getKey();
        }
    }

    return label;
}

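All you need to do is add K as a parameter and pass a value for it where classify is called in main, e.g. classify(trainingSet, matrix.getData(), 3). Keep in mind, however, that the code above does not handle ties in any particular way: if two class codes share the highest count, whichever one the HashMap happens to iterate over first wins.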
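Once classify takes K, trying different values is just a loop over the validation set for each K. The self-contained sketch below shows that loop on a tiny made-up 2-D data set (not the CSV files from the question); the names `ChooseK`, `dist`, `classify` and `accuracy` are hypothetical stand-ins for the `ImageMatrixDB` versions. Ties in this sketch go to the class that first reaches the top count while scanning outward from the nearest neighbour.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: compare validation accuracy for several values of k.
class ChooseK {
    // Squared Euclidean distance; preserves nearest-first ordering without sqrt.
    static int dist(int[] a, int[] b) {
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            int diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // kNN vote over the k nearest training points (assumes train is non-empty).
    static int classify(int[][] train, int[] labels, int[] query, int k) {
        int[] d = new int[train.length];
        Integer[] idx = new Integer[train.length];
        for (int i = 0; i < train.length; i++) {
            d[i] = dist(train[i], query);
            idx[i] = i;
        }
        Arrays.sort(idx, (a, b) -> Integer.compare(d[a], d[b])); // nearest first
        Map<Integer, Integer> counts = new HashMap<>();
        int best = labels[idx[0]];
        int bestCount = 0;
        for (int i = 0; i < Math.min(k, idx.length); i++) {
            int label = labels[idx[i]];
            int c = counts.merge(label, 1, Integer::sum);
            if (c > bestCount) { // strictly greater: earlier leader keeps ties
                best = label;
                bestCount = c;
            }
        }
        return best;
    }

    // Percentage of test points whose predicted label matches the true one.
    static double accuracy(int[][] train, int[] trainLabels, int[][] test, int[] testLabels, int k) {
        int correct = 0;
        for (int i = 0; i < test.length; i++) {
            if (classify(train, trainLabels, test[i], k) == testLabels[i]) {
                correct++;
            }
        }
        return 100.0 * correct / test.length;
    }

    public static void main(String[] args) {
        // Made-up data: three class-0 points near the origin, one class-1 point at (2, 2)
        int[][] train = {{0, 0}, {0, 1}, {1, 0}, {2, 2}};
        int[] trainLabels = {0, 0, 0, 1};
        int[][] test = {{0, 0}, {2, 2}};
        int[] testLabels = {0, 1};
        for (int k : new int[] {1, 3}) {
            System.out.println("k = " + k + ": " + accuracy(train, trainLabels, test, testLabels, k) + "%");
        }
        // prints "k = 1: 100.0%" then "k = 3: 50.0%"
    }
}
```

On this toy set a larger K outvotes the lone class-1 point, which illustrates why K is usually chosen by measuring validation accuracy rather than fixed in advance.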