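I recently tried building my first project with neural networks, and this is what I came up with. I want it to recognize MNIST handwritten digits. The problem is that when I run this code and train it for about 400,000 iterations, I get roughly 28% accuracy on the test data. Is that to be expected? Is 400k iterations too few to get better results, or is it because my neural network can only have one hidden layer?
In short: is this what it should look like, or am I doing something wrong? There is a lot of redundant code below; I was just trying to get it working.
All of this assumes my neural network itself works correctly.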
    public static void main(String[] args) {
        List<Data> trainData = new ArrayList<>();
        List<Data> testData = new ArrayList<>();
        byte[] trainLabels;
        byte[] trainImages;
        byte[] testLabels;
        byte[] testImages;
        try {
            Path tempPath1 = Paths.get("res/train-labels-idx1-ubyte");
            trainLabels = Files.readAllBytes(tempPath1);
            ByteBuffer bufferLabels = ByteBuffer.wrap(trainLabels);
            int magicLabels = bufferLabels.getInt();
            int numberOfItems = bufferLabels.getInt();
            Path tempPath = Paths.get("res/train-images-idx3-ubyte");
            trainImages = Files.readAllBytes(tempPath);
            ByteBuffer bufferImages = ByteBuffer.wrap(trainImages);
            int magicImages = bufferImages.getInt();
            int numberOfImageItems = bufferImages.getInt();
            int rows = bufferImages.getInt();
            int cols = bufferImages.getInt();
            for (int i = 0; i < numberOfItems; i++) {
                int t = bufferLabels.get();
                double[] target = createTargets(t);
                double[] inputs = new double[rows * cols];
                for (int j = 0; j < inputs.length; j++) {
                    inputs[j] = bufferImages.get();
                }
                Data tobj = new Data(inputs, target);
                trainData.add(tobj);
            }
            tempPath = Paths.get("res/t10k-labels-idx1-ubyte");
            testLabels = Files.readAllBytes(tempPath);
            ByteBuffer testLabelBuffer = ByteBuffer.wrap(testLabels);
            int testMagicLabels = testLabelBuffer.getInt();
            int numberOfTestLabels = testLabelBuffer.getInt();
            tempPath = Paths.get("res/t10k-images-idx3-ubyte");
            testImages = Files.readAllBytes(tempPath);
            ByteBuffer testImageBuffer = ByteBuffer.wrap(testImages);
            int testMagicImages = testImageBuffer.getInt();
            int numberOfTestImages = testImageBuffer.getInt();
            int testRows = testImageBuffer.getInt();
            int testCols = testImageBuffer.getInt();
            for (int i = 0; i < numberOfTestImages; i++) {
                double[] target = new double[]{testLabelBuffer.get()};
                double[] inputs = new double[testRows * testCols];
                for (int j = 0; j < inputs.length; j++) {
                    inputs[j] = testImageBuffer.get();
                }
                Data tobj = new Data(inputs, target);
                testData.add(tobj);
            }
            NeuralNetwork neuralNetwork = new NeuralNetwork(784, 64, 10);
            int len = trainData.size();
            Random randomGenerator = new Random();
            for (int i = 0; i < 400000; i++) {
                int randomInt = randomGenerator.nextInt(len);
                neuralNetwork.train(trainData.get(randomInt).getInputs(), trainData.get(randomInt).getTargets());
            }
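            float rightAnswers = 0;
            for (Data testObj : testData) {
                double[] output = neuralNetwork.feedforward(testObj.getInputs());
                double[] answer = testObj.getTargets();
                // Comparison missing from the post; assumed here: largest output index vs. label.
                int guess = 0;
                for (int k = 1; k < output.length; k++) {
                    if (output[k] > output[guess]) guess = k;
                }
                if (guess == (int) answer[0]) rightAnswers++;
            }
            float percentage = 100 * rightAnswers / testData.size();
            System.out.println(percentage);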
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static double[] createTargets(int number) {
        double[] result = new double[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
        result[number] = 1;
        return result;
    }
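Since the question turns on how test accuracy is measured, here is a minimal sketch of one common way to score a 10-output classifier: take the index of the largest output as the predicted digit and compare it with the label. The class name AccuracyCheck and the methods argMax and accuracyPercent are hypothetical names for illustration, not part of the poster's NeuralNetwork class, and the sample outputs are made up.

```java
// Hypothetical helper, not from the original post.
class AccuracyCheck {
    // Index of the largest value in the output vector (the predicted class).
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Percentage of rows whose arg-max matches the expected label.
    static double accuracyPercent(double[][] outputs, int[] labels) {
        int correct = 0;
        for (int i = 0; i < outputs.length; i++) {
            if (argMax(outputs[i]) == labels[i]) {
                correct++;
            }
        }
        return 100.0 * correct / outputs.length;
    }

    public static void main(String[] args) {
        // Made-up network outputs for four samples and three classes.
        double[][] outputs = {
            {0.1, 0.8, 0.1},
            {0.7, 0.2, 0.1},
            {0.2, 0.3, 0.5},
            {0.6, 0.3, 0.1}
        };
        int[] labels = {1, 0, 2, 1};
        // Three of the four predictions match, so this prints 75.0.
        System.out.println(accuracyPercent(outputs, labels));
    }
}
```

For reference, a network that guesses uniformly at random among ten digits would score about 10% on this measure, so 28% is above chance but far below what a working setup usually reaches.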
Answer 0 (score: 0)
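In case anyone is interested, the mistake was mine. After logging everything, I noticed that the input pixel values included negative numbers, whereas the MNIST documentation says they should range from 0 to 255. (Java's byte is signed, so ByteBuffer.get() returns values from -128 to 127 unless the result is masked with 0xFF.) On top of that, my inputs were not normalized, so some were 0 while others were 255. Here is what I added; hopefully I haven't missed anything. My accuracy is now around 90%.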
    for (int i = 0; i < numberOfTestImages; i++) {
        // Mask with 0xFF so the signed byte is read as an unsigned value
        double[] target = new double[]{testLabelBuffer.get() & 0xFF};
        double[] inputs = new double[testRows * testCols];
        for (int j = 0; j < inputs.length; j++) {
            // Normalize input from 0-255 to 0-1
            double temp = (testImageBuffer.get() & 0xFF) / 255.0;
            inputs[j] = temp;
        }
        Data tobj = new Data(inputs, target);
        testData.add(tobj);
    }
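The same fix can be shown in isolation. The sketch below is self-contained and runs without the MNIST files: it wraps a few hand-picked bytes in a ByteBuffer, shows what they look like without the mask, and then scales them to the 0-1 range. PixelScaling, toUnitRange and readPixels are hypothetical names chosen for this example; only ByteBuffer.wrap and ByteBuffer.get come from the Java standard library.

```java
import java.nio.ByteBuffer;

// Hypothetical helper, not from the original post.
class PixelScaling {
    // Read a signed Java byte as its unsigned value (0-255), then scale to [0, 1].
    static double toUnitRange(byte raw) {
        return (raw & 0xFF) / 255.0;
    }

    // Read `count` pixels from the buffer's current position and scale each one.
    static double[] readPixels(ByteBuffer buffer, int count) {
        double[] pixels = new double[count];
        for (int i = 0; i < count; i++) {
            pixels[i] = toUnitRange(buffer.get());
        }
        return pixels;
    }

    public static void main(String[] args) {
        // Stand-ins for the pixel intensities 0, 127, 128 and 255.
        byte[] raw = {0, 127, (byte) 128, (byte) 255};
        // Without the mask, the last two bytes widen to -128 and -1.
        System.out.println((int) raw[2] + " " + (int) raw[3]);
        // With the mask and division, every value lands between 0.0 and 1.0.
        for (double v : readPixels(ByteBuffer.wrap(raw), raw.length)) {
            System.out.println(v);
        }
    }
}
```

If this approach is used, the training loop presumably needs the same conversion as the test loop, so that both sets of inputs are on the same scale.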