Neural network MNIST digit recognition

Time: 2019-04-07 21:36:49

Tags: java neural-network mnist

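I recently tried to build my first project with neural networks, and this is what I came up with. I want it to recognize MNIST handwritten digits. The problem is that when I run this code and let it train for about 400,000 iterations, I get an accuracy of about 28% on the test data. Is that to be expected? Is 400k too few iterations to get better results, or is it because my neural network can only have one hidden layer?

To put the question briefly: is this how it should look, or am I doing something wrong? There is a lot of redundant code below; I was just trying to get it working.

All of this assumes that my neural network itself works correctly.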

public static void main(String[] args) {

    List<Data> trainData = new ArrayList<>();
    List<Data> testData = new ArrayList<>();

    byte[] trainLabels;
    byte[] trainImages;
    byte[] testLabels;
    byte[] testImages;

    try {

        Path tempPath1 = Paths.get("res/train-labels-idx1-ubyte");
        trainLabels = Files.readAllBytes(tempPath1);
        ByteBuffer bufferLabels = ByteBuffer.wrap(trainLabels);
        int magicLabels = bufferLabels.getInt();
        int numberOfItems = bufferLabels.getInt();

        Path tempPath = Paths.get("res/train-images-idx3-ubyte");
        trainImages = Files.readAllBytes(tempPath);
        ByteBuffer bufferImages = ByteBuffer.wrap(trainImages);
        int magicImages = bufferImages.getInt();
        int numberOfImageItems = bufferImages.getInt();
        int rows = bufferImages.getInt();
        int cols = bufferImages.getInt();

        for (int i = 0; i < numberOfItems; i++) {
            int t = bufferLabels.get();
            double[] target = createTargets(t);
            double[] inputs = new double[rows * cols];
            for (int j = 0; j < inputs.length; j++) {
                inputs[j] = bufferImages.get();
            }
            Data tobj = new Data(inputs, target);
            trainData.add(tobj);
        }

        tempPath = Paths.get("res/t10k-labels-idx1-ubyte");
        testLabels = Files.readAllBytes(tempPath);
        ByteBuffer testLabelBuffer = ByteBuffer.wrap(testLabels);
        int testMagicLabels = testLabelBuffer.getInt();
        int numberOfTestLabels = testLabelBuffer.getInt();

        tempPath = Paths.get("res/t10k-images-idx3-ubyte");
        testImages = Files.readAllBytes(tempPath);
        ByteBuffer testImageBuffer = ByteBuffer.wrap(testImages);
        int testMagicImages = testImageBuffer.getInt();
        int numberOfTestImages = testImageBuffer.getInt();
        int testRows = testImageBuffer.getInt();
        int testCols = testImageBuffer.getInt();

        for (int i = 0; i < numberOfTestImages; i++) {
            double[] target = new double[]{testLabelBuffer.get()};
            double[] inputs = new double[testRows * testCols];
            for (int j = 0; j < inputs.length; j++) {
                inputs[j] = testImageBuffer.get();
            }
            Data tobj = new Data(inputs, target);
            testData.add(tobj);
        }

        NeuralNetwork neuralNetwork = new NeuralNetwork(784, 64, 10);

        int len = trainData.size();
        Random randomGenerator = new Random();
        for (int i = 0; i < 400000; i++) {
            int randomInt = randomGenerator.nextInt(len);
            neuralNetwork.train(trainData.get(randomInt).getInputs(), trainData.get(randomInt).getTargets());
        }

        float rightAnswers = 0;

        for (Data testObj : testData) {
            double[] output = neuralNetwork.feedforward(testObj.getInputs());
            double[] answer = testObj.getTargets();
        }
        System.out.println(percentage);
    } catch (IOException e) {
        e.printStackTrace();
    }
}

public static double[] createTargets(int number) {
    double[] result = new double[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    result[number] = 1;
    return result;
}
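
The evaluation loop in the listing above reads `output` and `answer` but never updates `rightAnswers`, and `percentage` is not defined anywhere in the code shown. How the asker actually scored the network is not visible, so the following is only a minimal sketch of one common approach: take the index of the largest output as the predicted digit and compare it with the label. The class name `Accuracy` and the methods `argmax` and `percentCorrect` are hypothetical and do not come from the original code; the sketch also assumes the network returns one score per digit class.

```java
final class Accuracy {

    // Returns the index of the largest value, taken here as the predicted class.
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Percentage of samples whose predicted class equals the label.
    // outputs[i] holds the network scores for sample i (assumed layout).
    static double percentCorrect(double[][] outputs, int[] labels) {
        int right = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(outputs[i]) == labels[i]) {
                right++;
            }
        }
        return 100.0 * right / labels.length;
    }

    public static void main(String[] args) {
        // Made-up scores for four samples and three classes, for illustration only.
        double[][] outputs = {
            {0.1, 0.8, 0.1},
            {0.7, 0.2, 0.1},
            {0.2, 0.3, 0.5},
            {0.6, 0.3, 0.1}
        };
        int[] labels = {1, 0, 2, 1};
        System.out.println(percentCorrect(outputs, labels)); // prints 75.0
    }
}
```

In the question's code, the label would be `(int) answer[0]` and the scores would be the `output` array returned by `feedforward`, assuming that method returns the ten output activations.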

1 answer:

Answer 0: (score: 0)

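In case anyone is interested, the mistake was mine. Once I logged everything, I noticed that my input pixel values ranged from -255 to 255, whereas the MNIST documentation says they should be 0-255. On top of that, my inputs were not normalized, so some were 0 and others were 255. This is what I added. I hope I haven't missed anything. My accuracy is now 90%.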

for (int i = 0; i < numberOfTestImages; i++) {

    // Mask with 0xFF so the signed byte is read as an unsigned value
    double[] target = new double[]{testLabelBuffer.get() & 0xFF};
    double[] inputs = new double[testRows * testCols];
    for (int j = 0; j < inputs.length; j++) {
        // Normalize input from 0-255 to 0-1
        double temp = (testImageBuffer.get() & 0xFF) / 255f;
        inputs[j] = temp;
    }
    Data tobj = new Data(inputs, target);
    testData.add(tobj);
}
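
For anyone wondering why the mask is needed: Java's `byte` is a signed type with range -128 to 127, so `ByteBuffer.get()` returns every pixel value above 127 as a negative number. Masking with `0xFF` recovers the unsigned 0-255 value stored in the file. Below is a small standalone sketch of the same idea; the class name `PixelDecoding` and its helper methods are hypothetical and only for illustration. The snippet above changes only the test loop, but the training loop in the question reads pixels the same way, so it presumably needs the same masking and scaling.

```java
import java.nio.ByteBuffer;

final class PixelDecoding {

    // Java bytes are signed (-128..127); masking with 0xFF yields
    // the unsigned value (0..255).
    static int toUnsigned(byte b) {
        return b & 0xFF;
    }

    // Scales an unsigned pixel value from 0..255 to 0.0..1.0.
    static double normalize(byte b) {
        return toUnsigned(b) / 255.0;
    }

    // Reads `count` bytes from the buffer and returns them as normalized pixels.
    static double[] readPixels(ByteBuffer buffer, int count) {
        double[] pixels = new double[count];
        for (int i = 0; i < count; i++) {
            pixels[i] = normalize(buffer.get());
        }
        return pixels;
    }

    public static void main(String[] args) {
        byte[] raw = {0, 127, (byte) 128, (byte) 255};
        for (byte b : raw) {
            // The left-hand value is what the unmasked read would produce.
            System.out.println(b + " -> " + toUnsigned(b));
        }
        double[] pixels = readPixels(ByteBuffer.wrap(raw), raw.length);
        System.out.println(pixels[0] + " " + pixels[3]);
    }
}
```

Dividing by `255.0` rather than `255f` keeps the arithmetic in `double`, which matches the `double[]` inputs; either works for this purpose.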