XOR神经网络不收敛

时间:2016-02-03 17:51:41

标签: neural-network

我在让我的XOR神经网络收敛方面遇到了问题。它有两个输入,隐藏层中有2个节点,还有一个输出节点。我认为这与我的反向传播算法有关,但我试图弄清楚问题出在哪里,但我不能。我也对所有算法进行了广泛的研究,看起来它们都是正确的。

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Random;

public class NeuralNetwork {

    public static class Perceptron {
        public ArrayList<Perceptron> inputs;
        public ArrayList<Double> inputWeight;
        public double output;
        public double error;
        private double bias = 1;
        private double biasWeight;
        public boolean activationOn = false;

        //sets up non input layers
        public Perceptron(ArrayList<Perceptron> in) {
            inputWeight = new ArrayList<Double>(in.size());
            inputs = in;

            initWeight(in.size());
        }

        //basic constructor
        public Perceptron() { }

        //generate random weights
        private void initWeight(int size) {
            Random generator = new Random();

            for(int i=0; i<size; i++) 
                inputWeight.add(i, ((generator.nextDouble())));

            biasWeight = (generator.nextDouble());
        }

        //calculate output based on current outputs of last layer
        public double calculateOutput() {
            double num = 0;

            num = bias*biasWeight;
            for(int i=0; i<inputs.size(); i++) 
                num += inputs.get(i).output * inputWeight.get(i);

            output = num;

            if(activationOn)
                output = sigmoid(output);
            else
                output = threshold(output);

            return output;
        }


        //methods used for learning

        //calculate output error
        public double calcOutputError(double expected){
            error = output * (1 - output) * (expected - output);
            return error;
        }

        //calculate node blame
        public void blame(double outError, double outWeight) {
            error = output * (1 - output) * outWeight * outError;
        }

        //adjust weights
        public void adjustWeight() {
            double alpha = .5;
            double newWeight = 0;
            for(int i=0; i<inputs.size(); i++) {
                newWeight = inputWeight.get(i) + alpha * inputs.get(i).output * error;
                inputWeight.set(i, newWeight);
            }

            //adjust bias weight
            newWeight = biasWeight + alpha * bias * error;
            biasWeight = newWeight;
            //System.out.println("Weight " + biasWeight);
        }

        //returns the sigmoid of x
        private double sigmoid(double x) {
            return (1 / ( 1 + Math.pow(Math.E, -x)));
        }

        //returns threshold of x
        private double threshold(double x) {
            if(x>=0.5)
                return 1;
            else
                return 0;
        }
    }   

    //teaches a neural network XOR
    public static void teachXOR(ArrayList<Perceptron> inputs, ArrayList<Perceptron> hidden, Perceptron output) {
        int examples[][] = { {0,0,0},
                             {1,1,0},
                             {0,1,1},
                             {1,0,1} };
        boolean examplesFix[] = {false, false, false, false};
        int layerSize = 2;
        boolean learned = false;
        boolean fixed;
        int limit = 50000;

        while(!learned && limit > 0) {  
            learned = true;
            limit--;

            //turn on using activation function
            for(int i=0; i<2; i++)
                hidden.get(i).activationOn = true;
            output.activationOn = true;

            for(int i=0; i<4; i++) {
                examplesFix[i] = false;
                //set up inputs
                for(int j=0; j<layerSize; j++)
                    inputs.get(j).output = examples[i][j];

                //calculate outputs for hidden layer
                for(int j=0; j<layerSize; j++)
                    hidden.get(j).calculateOutput();

                //calculate final output
                double outValue = output.calculateOutput();

                System.out.println("Check output " + examples[i][0] + "," + examples[i][1] + " = " + outValue);

                if(((outValue < .5 && examples[i][2] == 1) || (outValue > .5 && examples[i][2] == 0))) {
                    learned = false;
                    examplesFix[i] = true;
                }
            }           

            //turn on using activation function
            for(int i=0; i<2; i++)
                hidden.get(i).activationOn = true;
            output.activationOn = true;

            //teach the nodes that are incorrect
            if(!learned && limit >= 0) {
                for(int i=0; i<4; i++) {
                    if(examplesFix[i]) {
                        fixed = false;
                        while(!fixed) {                     
                            //System.out.println("Adjusting weight: " + examples[i][0] + "," + examples[i][1] + " --> " + examples[i][2]);
                            for(int j=0; j<layerSize; j++)
                                inputs.get(j).output = examples[i][j];

                            //calculate outputs for hidden layer
                            for(int j=0; j<layerSize; j++) 
                                hidden.get(j).calculateOutput();

                            //calculate final output
                            double outValue = output.calculateOutput();             

                            if((outValue >= .5 && examples[i][2] == 1) || (outValue < .5 && examples[i][2] == 0)) {
                                fixed = true;
                            }
                            else {
                                double outError = output.calcOutputError(examples[i][2]);
                                //blame the hidden layer nodes
                                for(int j=0; j<layerSize; j++)
                                    hidden.get(j).blame(outError, output.inputWeight.get(j));

                                //adjust weights
                                for(int j=0; j<layerSize; j++)
                                    hidden.get(j).adjustWeight();
                                output.adjustWeight();  
                            }
                        }
                    }
                }
            }
        }
        //if(limit <= 0) 
        //  System.out.println("Did not converge");//, error: " + output.error);
        //System.out.println("Done");
    }

    //runs tests for XOR, not complete
    public static void runXOR(ArrayList<Perceptron> inputs, ArrayList<Perceptron> hidden, Perceptron output) throws IOException {
        //create new file
        PrintWriter writer;
        File file = new File("Test.csv");
        if(file.exists())
            file.delete();
        file.createNewFile();
        writer = new PrintWriter(file);

        ArrayList<String> positive = new ArrayList<String>();
        ArrayList<String> negative = new ArrayList<String>();

        //turn off using activation function
        for(int i=0; i<2; i++)
            hidden.get(i).activationOn = false;
        output.activationOn = false;

        //tests 10,000 points
        for(int i=0; i<=100; i++) {
            for(int j=0; j<=100; j++) {
                inputs.get(0).output = (double)i/100;
                inputs.get(1).output = (double)j/100;

                //calculate outputs for hidden layer
                for(int k=0; k<2; k++) 
                    hidden.get(k).calculateOutput();

                //calculate final output
                double outValue = output.calculateOutput();

                //keep track of positive and negative results
                if(outValue >= .5) {
                    positive.add((double)i/100 + "," + (double)j/100 + "," + outValue);
                    //writer.println((double)i/100 + "," + (double)j/100 + ",1");
                }
                else if(outValue < .5) {
                    negative.add((double)i/100 + "," + (double)j/100 + "," + outValue);
                    //writer.println((double)i/100 + "," + (double)j/100 + ",0");
                }
            }
        }

        //write out to file
        writer.println("X,Y,Positive,X,Y,Negative");

        int i = 0;
        while(i<positive.size() && i<negative.size()) {
            writer.println(positive.get(i) + "," + negative.get(i));
            i++;
        }
        while(i<positive.size()) {
            writer.println(positive.get(i));
            i++;
        }
        while(i<negative.size()) {
            writer.println(",,,"  + negative.get(i));
            i++;
        }

        writer.close();
    }


    //used for testing
    public static void main(String[] args) throws IOException {
        int layerSize = 2;
        ArrayList<Perceptron> inputLayer;
        ArrayList<Perceptron> hiddenLayer;
        Perceptron outputLayer;

        //XOR neural network
        inputLayer = new ArrayList<Perceptron>(layerSize);
        hiddenLayer = new ArrayList<Perceptron>(layerSize);

        //for(Perceptron per : inputLayer) 
        //  per = new Perceptron();

        for(int i=0; i<layerSize; i++) 
            inputLayer.add(new Perceptron());

        for(int i=0; i<layerSize; i++) 
            hiddenLayer.add(new Perceptron(inputLayer));

        outputLayer = new Perceptron(hiddenLayer);


        teachXOR(inputLayer, hiddenLayer, outputLayer);
        runXOR(inputLayer, hiddenLayer, outputLayer);
    }
}

1 个答案:

答案 0 :(得分:0)

首先,您的代码具有非常独特的结构,并且难以调试。我会考虑从头开始编写它,结构更清晰,内部字段更少,更多实际函数返回值。

一个主要错误(可能不是唯一错误)是您在隐藏层中输出和learnOutput之间的区别。当您计算输出图层的激活时,您实际使用&#34;输出&#34;字段,虽然你应该使用learnOutput(这是唯一一个实际使用sigmoid激活的人)。

此外 - 如果您正确地重新构建代码,您可以为数值梯度测试创建单元测试,这是您在使用神经网络/其他渐变训练机器时始终所做的事情。在这种情况下,它会向您显示您的渐变不正确。