NeuralNetwork无法训练XOR(但OR可以正常工作)

时间:2019-08-18 11:38:23

标签: java machine-learning

摘要

嘿,我是机器学习的新手,我已经尝试编写自己的神经网络了。作为一个简单的任务,我选择训练“或”门“行为”,效果很好。比我想尝试XOR门的地方,问题来了。神经网络似乎无法正确训练,因为即使经过训练后,输出也与预期输出不匹配。

我试图回溯问题,但是我没有成功,也许是因为我在这一领域没有经验。 作为看如何编码反向传播的“指南”,我使用了Finn Eggers撰写的有关YT的教程。播放列表的名称为:NN-完全连接的教程。

预期:0.0、0.0-> 0.0 | 0.0、1.0-> 1.0 | 1.0,0.0-> 1.0 | 1.0,1.0-> 0.0 |。

实际舍入:0.0,0.0-> 0.29 | 0.0、1.0-> 0.64 | 1.0,0.0-> 0.53 | 1.0,1.0-> 0.59 |。

最后一个结果并不是真正的预期结果(或可以接受,因为我不能仅仅看它是否大于0.5或小于1或0)。

为您提供帮助。

完整代码

public class NeuralNetwork 
{
    public static void main(String[] args)
    {
        NeuralNetwork nn = new NeuralNetwork(new int[]{2,2,1});
        //training data for the nn
        double[] in1 = new double[]{0.0, 0.0};
        double[] in2 = new double[]{0.0, 1.0};
        double[] in3 = new double[]{1.0, 0.0};
        double[] in4 = new double[]{1.0, 1.0};

        double[] out1 = new double[]{0.0};
        double[] out2 = new double[]{1.0};
        double[] out3 = new double[]{1.0};
        double[] out4 = new double[]{0.0};

        int set = 0;
        //trains the nn with data from above
        for(int i = 0; i < 5000; i++)
        {
            if(set == 0)
            {
                nn.train(in1, out1, 0.3);
            }
            else if(set == 1)
            {
                nn.train(in2, out2, 0.3);
            }
            else if(set == 2)
            {
                nn.train(in3, out3, 0.3);
            }
            else
            {
                nn.train(in4, out4, 0.3);
            }
            set = (set + 1) % 4;
        }

        //test runs the data to see what output the nn is producing
        nn.run(in1);
        nn.run(in2);
        nn.run(in3);
        nn.run(in4);
    }

    public int[] layers;
    //input layer
    public Neuron[] input;
    //output layer
    public Neuron[] output;
    //neural network
    public Neuron[][] nn;

    public NeuralNetwork(int[] layers)
    {
        this.layers = layers;

        //creates the neural network
        this.nn = new Neuron[layers.length][];
        for(int i = 0; i < this.nn.length; i++)
        {
            this.nn[i] = new Neuron[layers[i]];
            for(int k = 0; k < layers[i]; k++)
            {
                //creates a neruon and gives it an input if its not in the input layer of the nn
                this.nn[i][k] = (i == 0)? new Neuron(): new Neuron(this.nn[i - 1]);
            }
        }

        //sets the input layer
        this.input = this.nn[0];
        //sets the output layer
        this.output = this.nn[this.nn.length - 1];
    }

    //calculates the output of the nn but doesnt train the nn
    public void run(double[] in)
    {
        calculate(in);
        double[] result = new double[this.output.length];
        for(int i = 0; i < this.output.length; i++)
        {
            result[i] = this.output[i].output;
        }
        System.out.println(Arrays.toString(result));
    }

    //trains the nn
    public void train(double[] in, double[] out, double rate)
    {
        calculate(in);
        backPropagate(out);
        updateWeights(rate);
    }

    //calculates the output of the nn
    public void calculate(double[] in)
    {
        //sets the input of the neural network
        for(int i = 0; i < in.length; i++)
        {
            this.nn[0][i].output = in[i];
        }

        //runs the neural network
        for(int i = 1; i < this.nn.length; i++)
        {
            for(int k = 0; k < this.nn[i].length; k++)
            {
                this.nn[i][k].calculateOutput();
            }
        }
    }

    //backpropagates the error (atleast it should)
    public void backPropagate(double[] target)
    {
        //calculates the error of the output layer
        for(int i = 0; i < this.output.length; i++)
        {
            this.output[i].error = (this.output[i].output - target[i]) * this.output[i].derivative;
        }
        //calculates the error of the hidden layer
        for(int i = this.nn.length - 2; i > 0; i--)
        {
            for(int k = 0; k < this.nn[i].length; k++)
            {
                double sum = 0;
                for(int a = 0; a < this.nn[i + 1].length; a++)
                {
                    sum += this.nn[i + 1][a].weights[k] * this.nn[i + 1][a].error;
                }
                this.nn[i][k].error = sum * this.nn[i][k].derivative;
            }
        }
    }

    //updates the weights of the hidden layer(s) and the output layer
    public void updateWeights(double rate)
    {
        for(int i = 1; i < this.nn.length; i++)
        {
            for(int k = 0; k < this.nn[i].length; k++)
            {
                double delta = -rate * this.nn[i][k].error;

                this.nn[i][k].bias += delta;

                for(int a = 0; a < this.nn[i - 1].length; a++)
                {
                    this.nn[i][k].weights[a] += delta * this.nn[i - 1][a].output;
                }
            }
        }
    }
}

public class Neuron 
{
    public double error;
    public double output;
    public double bias;
    public double derivative;
    public double[] weights;
    public Neuron[] inputs;

    //constructor for neurons of the hidden / output layer
    public Neuron(Neuron[] inputs)
    {
        //sets the inputs of the neuron
        this.inputs = inputs;
        //creates random weights
        this.weights = new double[inputs.length];
        for(int i = 0; i < this.weights.length; i++)
        {
            this.weights[i] = ThreadLocalRandom.current().nextDouble();
        }
        //sets a random bias
        this.bias = ThreadLocalRandom.current().nextDouble();
    }

    //empty constructor (for the input layer)
    public Neuron()
    {

    }

    public void calculateOutput()
    {
        appyActivationFunction(calculateSum());
        calculateDerivitive();
    }

    public double calculateSum()
    {
        double sum = this.bias;
        for(int i = 0; i < this.inputs.length; i++)
        {
            sum += this.weights[i] * this.inputs[i].output;
        }
        return sum;
    }

    public void appyActivationFunction(double in)
    {
        //applys the sigmoid function
        this.output = 1.0 / (1.0 + Math.exp(-in));
    }

    //calculates the derivative
    public void calculateDerivitive()
    {
        this.derivative = this.output * (1.0 - this.output);
    }
}

0 个答案:

没有答案