神经网络:反向传播不起作用(Java)

时间:2017-01-21 10:52:09

标签: java neural-network backpropagation

我必须为学校项目创建一个OCR程序,所以我开始在维基百科的帮助下创建一个Backpropagation算法。为了训练我的网络,我使用了几天前我提取的MNIST数据库,以便我拥有真实的图像文件。但现在错误总是大约237,经过一段时间的训练后,误差和重量变为NaN。我的代码出了什么问题?

A screenshot of my images folder

这是我的主课,它将训练我的网络:

package de.Marcel.NeuralNetwork;

import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

import javax.imageio.ImageIO;

public class OCR {
    public static void main(String[] args) throws IOException {
        // create network
        NeuralNetwork net = new NeuralNetwork(784, 450, 5, 0.2);

    // load Images
    File file = new File("images");

    int images= 0;
    double error = 0;
    for (File f : file.listFiles()) {
        BufferedImage image = ImageIO.read(f);

        int t = -1;
        double[] pixels = new double[784];
        for (int x = 0; x < image.getWidth(); x++) {
            for (int y = 0; y < image.getHeight(); y++) {
                t++;
                Color c = new Color(image.getRGB(x, y));

                if (c.getRed() == 0 && c.getGreen() == 0 && c.getBlue() == 0) {
                    pixels[t] = 1;
                } else if (c.getRed() == 255 && c.getGreen() == 255 && c.getBlue() == 255) {
                    pixels[t] = 0;
                }
            }
        }

        try {
            if (f.getName().startsWith("1")) {
                net.learn(pixels, new double[] { 1, 0, 0, 0, 0 });
                error += net.getError();

                images++;
            } else if (f.getName().startsWith("2")) {
                net.learn(pixels, new double[] { 0, 1, 0, 0, 0 });
                error += net.getError();

                images++;
            } else if (f.getName().startsWith("3")) {
                net.learn(pixels, new double[] { 0, 0, 1, 0, 0 });
                error += net.getError();

                images++;
            } else if (f.getName().startsWith("4")) {
                net.learn(pixels, new double[] { 0, 0, 0, 1, 0 });
                error += net.getError();

                images++;
            } else if (f.getName().startsWith("5")) {
                net.learn(pixels, new double[] { 0, 0, 0, 0, 1 });
                error += net.getError();

                images++;
            } else if (f.getName().startsWith("6")) {
                break;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    error = error / iterations;

    System.out.println("Trained images: " + images);
    System.out.println("Error: " + error);

    //save
    System.out.println("Save");
    try {
        net.saveNetwork("network.nnet");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
}

......这是我的神经元课程:

package de.Marcel.NeuralNetwork;

public class Neuron {
    private double input, output;

public Neuron () {

}

public void setInput(double input) {
    this.input = input;
}

public void setOutput(double output) {
    this.output = output;
}

public double getInput() {
    return input;
}

public double getOutput() {
    return output;
}

}

...最后是我的NeuralNetwork

package de.Marcel.NeuralNetwork;

import java.io.File;
import java.io.FileWriter;
import java.util.Random;

public class NeuralNetwork {
    private Neuron[] inputNeurons, hiddenNeurons, outputNeurons;
    private double[] weightMatrix1, weightMatrix2;
    private double learningRate, error;

public NeuralNetwork(int inputCount, int hiddenCount, int outputCount, double learningRate) {
    this.learningRate = learningRate;

    // create Neurons
    // create Input
    this.inputNeurons = new Neuron[inputCount];
    for (int i = 0; i < inputCount; i++) {
        this.inputNeurons[i] = new Neuron();
    }
    // createHidden
    this.hiddenNeurons = new Neuron[hiddenCount];
    for (int i = 0; i < hiddenCount; i++) {
        this.hiddenNeurons[i] = new Neuron();
    }
    // createOutput
    this.outputNeurons = new Neuron[outputCount];
    for (int i = 0; i < outputCount; i++) {
        this.outputNeurons[i] = new Neuron();
    }

    // create weights
    Random random = new Random();
    // weightMatrix1
    this.weightMatrix1 = new double[inputCount * hiddenCount];
    for (int i = 0; i < inputCount * hiddenCount; i++) {
        this.weightMatrix1[i] = (random.nextDouble() * 2 - 1) / 0.25;
    }
    // weightMatrix2
    this.weightMatrix2 = new double[hiddenCount * outputCount];
    for (int i = 0; i < hiddenCount * outputCount; i++) {
        this.weightMatrix2[i] = (random.nextDouble() * 2 - 1) / 0.25;
    }
}

public void calculate(double[] input) throws Exception {
    // verfiy input length
    if (input.length == inputNeurons.length) {
        // forwardPropagation
        // set input array as input and output of input neurons
        for (int i = 0; i < input.length; i++) {
            inputNeurons[i].setInput(input[i]);
            inputNeurons[i].setOutput(input[i]);
        }

        // calculate output of hiddenNeurons
        for (int h = 0; h < hiddenNeurons.length; h++) {
            Neuron hNeuron = hiddenNeurons[h];
            double totalInput = 0;

            // sum up totalInput of Neuron
            for (int i = 0; i < inputNeurons.length; i++) {
                Neuron iNeuron = inputNeurons[i];
                totalInput += iNeuron.getOutput() * weightMatrix1[h * inputNeurons.length + i];
            }

            // set input
            hNeuron.setInput(totalInput);

            // calculate output by applying sigmoid
            double calculatedOutput = sigmoid(totalInput);

            // set output
            hNeuron.setOutput(calculatedOutput);
        }

        // calculate output of outputNeurons
        for (int o = 0; o < outputNeurons.length; o++) {
            Neuron oNeuron = outputNeurons[o];
            double totalInput = 0;

            // sum up totalInput of Neuron
            for (int h = 0; h < hiddenNeurons.length; h++) {
                Neuron hNeuron = hiddenNeurons[h];
                totalInput += hNeuron.getOutput() * weightMatrix2[o * hiddenNeurons.length + h];
            }

            // set input
            oNeuron.setInput(totalInput);

            // calculate output by applying sigmoid
            double calculatedOutput = sigmoid(totalInput);

            // set output
            oNeuron.setOutput(calculatedOutput);
        }
    } else {
        throw new Exception("[NeuralNetwork] input array is either too small or to big");
    }
}

public void learn(double[] input, double[] output) throws Exception {
    double partialOutput = 0;

    // verfiy input length
    if (input.length == inputNeurons.length) {
        // forwardPropagation
        // set input array as input and output of input neurons
        for (int i = 0; i < input.length; i++) {
            inputNeurons[i].setInput(input[i]);
            inputNeurons[i].setOutput(input[i]);
        }

        // calculate output of hiddenNeurons
        for (int h = 0; h < hiddenNeurons.length; h++) {
            Neuron hNeuron = hiddenNeurons[h];
            double totalInput = 0;

            // sum up totalInput of Neuron
            for (int i = 0; i < inputNeurons.length; i++) {
                Neuron iNeuron = inputNeurons[i];
                totalInput += iNeuron.getOutput() * weightMatrix1[h * inputNeurons.length + i];
            }

            // set input
            hNeuron.setInput(totalInput);

            // calculate output by applying sigmoid
            double calculatedOutput = sigmoid(totalInput);

            // set output
            hNeuron.setOutput(calculatedOutput);
        }

        // calculate output of outputNeurons
        for (int o = 0; o < outputNeurons.length; o++) {
            Neuron oNeuron = outputNeurons[o];
            double totalInput = 0;

            // sum up totalInput of Neuron
            for (int h = 0; h < hiddenNeurons.length; h++) {
                Neuron hNeuron = hiddenNeurons[h];
                totalInput += hNeuron.getOutput() * weightMatrix2[o * hiddenNeurons.length + h];
            }

            // set input
            oNeuron.setInput(totalInput);

            // calculate output by applying sigmoid
            double calculatedOutput = sigmoid(totalInput);

            // set output
            oNeuron.setOutput(calculatedOutput);
        }

        // backPropagation
        double totalError = 0;
        // calculate weights in matrix2
        for (int h = 0; h < hiddenNeurons.length; h++) {
            Neuron hNeuron = hiddenNeurons[h];

            for (int o = 0; o < outputNeurons.length; o++) {
                Neuron oNeuron = outputNeurons[o];

                // calculate weight
                double delta = learningRate * derivativeSigmoid(oNeuron.getInput())
                        * (output[o] - oNeuron.getOutput()) * hNeuron.getOutput();

                // set new weight
                weightMatrix2[h + o * hiddenNeurons.length] = weightMatrix2[h + o * hiddenNeurons.length] + delta;

                // update partial output
                partialOutput += (derivativeSigmoid(oNeuron.getInput()) * (output[o] - oNeuron.getOutput())
                        * weightMatrix2[h + o * hiddenNeurons.length]);

                //calculate error
                totalError += Math.pow((output[o] - oNeuron.getOutput()), 2);
            }
        }

        //set error
        this.error = 0.5 * totalError;

        // calculate weights in matrix1
        for (int i = 0; i < inputNeurons.length; i++) {
            Neuron iNeuron = inputNeurons[i];

            for (int h = 0; h < hiddenNeurons.length; h++) {
                Neuron hNeuron = hiddenNeurons[h];

                // calculate weight
                double delta = learningRate * derivativeSigmoid(hNeuron.getInput()) * partialOutput
                        * (iNeuron.getOutput());

                // set new weight
                weightMatrix1[i + h * inputNeurons.length] = weightMatrix1[i + h * inputNeurons.length] + delta;
            }
        }
    } else {
        throw new Exception("[NeuralNetwork] input array is either too small or to big");
    }
}

// save Network
public void saveNetwork(String fileName) throws Exception {
    File file = new File(fileName);
    FileWriter writer = new FileWriter(file);

    writer.write("weightmatrix1:");
    writer.write(System.lineSeparator());

    // write weightMatrix1
    for (double d : weightMatrix1) {
        writer.write(d + "-");
    }

    writer.write(System.lineSeparator());
    writer.write("weightmatrix2:");
    writer.write(System.lineSeparator());

    // write weightMatrix2
    for (double d : weightMatrix2) {
        writer.write(d + "-");
    }

    // save
    writer.close();
}

// sigmoid function
private double sigmoid(double input) {
    return Math.exp(input * (-1));
}

private double derivativeSigmoid(double input) {
    return sigmoid(input) * (1 - sigmoid(input));
}

public double getError() {
    return error;
}
}

1 个答案:

答案 0 :(得分:0)

看起来您的sigmoid功能不正确。它应该是1 /(1 + exp(-x))。

如果你仍然遇到NaN错误,可能是因为使用这样的功能可能是一种过度杀伤,特别是对于大数字(即,小于-10且大于10的数字)。

使用sigmoid(x)的预先计算值数组可以防止更大数据集的此问题,并且还可以帮助程序更有效地运行。

希望这有帮助!