我已经应用KNN算法对手写数字进行分类。数字最初为矢量格式8 * 8,并拉伸形成矢量1 * 64 ..
现在我的代码应用kNN算法,但只使用k = 1.我不完全确定在尝试了一些事情之后如何改变值k我不断抛出错误。如果有人能帮助我朝着正确的方向前进,我们将非常感激。可以找到训练数据集here和验证集here。
ImageMatrix.java
import java.util.*;
public class ImageMatrix {
private int[] data;
private int classCode;
private int curData;
public ImageMatrix(int[] data, int classCode) {
assert data.length == 64; //maximum array length of 64
this.data = data;
this.classCode = classCode;
}
public String toString() {
return "Class Code: " + classCode + " Data :" + Arrays.toString(data) + "\n"; //outputs readable
}
public int[] getData() {
return data;
}
public int getClassCode() {
return classCode;
}
public int getCurData() {
return curData;
}
}
ImageMatrixDB.java
import java.util.*;
import java.io.*;
import java.util.ArrayList;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
private List<ImageMatrix> list = new ArrayList<ImageMatrix>();
public ImageMatrixDB load(String f) throws IOException {
try (
FileReader fr = new FileReader(f);
BufferedReader br = new BufferedReader(fr)) {
String line = null;
while((line = br.readLine()) != null) {
int lastComma = line.lastIndexOf(',');
int classCode = Integer.parseInt(line.substring(1 + lastComma));
int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
.mapToInt(Integer::parseInt)
.toArray();
ImageMatrix matrix = new ImageMatrix(data, classCode); // Classcode->100% when 0 -> 0% when 1 - 9..
list.add(matrix);
}
}
return this;
}
public void printResults(){ //output results
for(ImageMatrix matrix: list){
System.out.println(matrix);
}
}
public Iterator<ImageMatrix> iterator() {
return this.list.iterator();
}
/// kNN implementation ///
public static int distance(int[] a, int[] b) {
int sum = 0;
for(int i = 0; i < a.length; i++) {
sum += (a[i] - b[i]) * (a[i] - b[i]);
}
return (int)Math.sqrt(sum);
}
public static int classify(ImageMatrixDB trainingSet, int[] curData) {
int label = 0, bestDistance = Integer.MAX_VALUE;
for(ImageMatrix matrix: trainingSet) {
int dist = distance(matrix.getData(), curData);
if(dist < bestDistance) {
bestDistance = dist;
label = matrix.getClassCode();
}
}
return label;
}
public int size() {
return list.size(); //returns size of the list
}
public static void main(String[] argv) throws IOException {
ImageMatrixDB trainingSet = new ImageMatrixDB();
ImageMatrixDB validationSet = new ImageMatrixDB();
trainingSet.load("cw2DataSet1.csv");
validationSet.load("cw2DataSet2.csv");
int numCorrect = 0;
for(ImageMatrix matrix:validationSet) {
if(classify(trainingSet, matrix.getData()) == matrix.getClassCode()) numCorrect++;
} //285 correct
System.out.println("Accuracy: " + (double)numCorrect / validationSet.size() * 100 + "%");
System.out.println();
}
答案 0 :(得分:2)
在 classify 的for循环中,您正在尝试找到最接近测试点的训练示例。您需要使用找到最接近测试数据的训练点的 K 的代码进行切换。然后你应该为每个K点调用getClassCode,并找到它们中的大多数(即最频繁的)类代码。 classify 将返回您找到的主要类代码。
您可以以任何适合您需要的方式打破关系(即,将2个最常用的类别代码分配给相同数量的训练数据)。
我对Java缺乏经验,但只是通过查看语言参考,我想出了下面的实现。
public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {
int label = 0, bestDistance = Integer.MAX_VALUE;
int[][] distances = new int[trainingSet.size()][2];
int i=0;
// Place distances in an array to be sorted
for(ImageMatrix matrix: trainingSet) {
distances[i][0] = distance(matrix.getData(), curData);
distances[i][1] = matrix.getClassCode();
i++;
}
Arrays.sort(distances, (int[] lhs, int[] rhs) -> lhs[0]-rhs[0]);
// Find frequencies of each class code
i = 0;
Map<Integer,Integer> majorityMap;
majorityMap = new HashMap<Integer,Integer>();
while(i < k) {
if( majorityMap.containsKey( distances[i][1] ) ) {
int currentValue = majorityMap.get(distances[i][1]);
majorityMap.put(distances[i][1], currentValue + 1);
}
else {
majorityMap.put(distances[i][1], 1);
}
++i;
}
// Find the class code with the highest frequency
int maxVal = -1;
for (Entry<Integer, Integer> entry: majorityMap.entrySet()) {
int entryVal = entry.getValue();
if(entryVal > maxVal) {
maxVal = entryVal;
label = entry.getKey();
}
}
return label;
}
您需要做的就是添加 K 作为参数。但请记住,上面的代码不能以特定方式处理关系。