前馈神经网络 - XOR学习非常慢

时间:2017-02-09 02:35:45

标签: java machine-learning neural-network backpropagation

我最近一直试图从头开始制作前馈神经网络。它已经能够学习XOR功能,虽然它比其他人描述的时间更多。

我的网络使用0.25的学习率,没有动量,激活函数是sigmoid函数。 (1 /(1 + E ^ -x))

我之前已经阅读过使用MNIST数据库的培训只需要1k个纪元左右。然而,对于像XOR这样简单的事情,我的网络需要10万个时期(几秒钟)才能获得正确的XOR答案的0.05。

例如,对于1,0,NN将输出0.96。

它最终会进行训练,但出于某种原因,它所采用的时代比描述的多了许多倍。

所以我的问题是:为什么我的网络需要这么长时间?我怀疑它与反向传播有关,虽然我不知道如何验证。

可以看到整个项目here on github.

前传:

private double evaluate(double[] input){
    //reset the neuron values
    for(NeuronLayer layer:layers){
        for(Neuron neuron:layer.getNeurons()){
            neuron.weightedSum = 0;
        }
    }

    //set the output of the input neurons to the input
    for(int i = 0; i < layers[0].getNeurons().length; i++){
        layers[0].getNeurons()[i].output = input[i];
    }

    //cycle through all the neurons
    for(int i = 0; i < layers.length; i++){
        for(Neuron neuron:layers[i].getNeurons()){
            if(i != 0) neuron.activationFunction(); //apply the activation function if not an input neuron
            if(i != layers.length - 1){
                for(Dendrite dendrite:neuron.getDendrites()){
                    //Increment the weightedSum of the destination neuron by the source neuron output scaled by the weight
                    dendrite.getEnd().weightedSum += neuron.output * dendrite.weight;
                }
            }
        }
    }
    double result = layers[layers.length-1].getNeurons()[0].weightedSum; //return the output of the first output neuron.
    return result;
}

获取神经元错误:

void getErrors(double result, double expectedResult){
    for(int i = layers.length - 1; i > 0; i--){
        NeuronLayer layer = layers[i];
        for(int j = 0; j < layer.getNeurons().length; j++){
            Neuron neuron = layer.getNeurons()[j];
            double neuronError = 0;
            if(i == layers.length - 1){
                neuronError = neuron.getDerivative() * (result - expectedResult);
            }
            else{
                neuronError = neuron.getDerivative();

                double sum = 0;
                for(Dendrite dendrite:neuron.getDendrites()){
                    sum += dendrite.weight * dendrite.getEnd().getError();
                }
                neuronError *= sum;
            }
            neuron.setError(neuronError);
        }
    }
}

根据错误更新权重:

void updateWeights(HashMap<Dendrite,Double> dendriteDeltaMap, double learningRate, double momentum){
    for(int i = layers.length - 1; i > 0; i--){
        NeuronLayer layer = layers[i];
        for(Neuron neuron:layer.getNeurons()){
            for(Dendrite dendrite:neuron.getInputs()){
                double delta = learningRate * neuron.getError() * dendrite.getStart().getOutput();
                if(dendriteDeltaMap.get(dendrite) != null){
                    delta += momentum * dendriteDeltaMap.get(dendrite);
                }
                dendriteDeltaMap.put(dendrite, delta);
                dendrite.adjustWeight(-delta);
            }
        }
    }
}

训练功能:

public void train(double[][] inputs, double[] outputs, double learningRate, double momentum, int maxIterations){
    int runs = 0;
    double startError = 0;
    while(true){
        HashMap<Dendrite,Double> dendriteDeltaMap = new HashMap<>();
        double errorSum = 0;
        for(int i = 0; i < inputs.length; i++){
            double sum = evaluate(inputs[i]);//get sum
            double result = sigmoid(sum); //calculate final result
            double error = Math.pow(outputs[i]-result,2)/2; //calculate mean squared error
            errorSum += error;
            //System.out.println("Error: " + error);
            getErrors(result, outputs[i]);
            updateWeights(dendriteDeltaMap, learningRate, momentum);
        }
        double avgError = errorSum/inputs.length;
        if(runs == 0) startError = avgError;
        System.out.println("Epoch: " + runs + ", error: " + avgError);
        runs++;
        if(runs>=maxIterations || avgError <= Math.pow(0.03, 2)/2) break;
    }
    System.out.println("\nFinished!");
    System.out.println("Start error: " + startError);
    printWeights();
}

0 个答案:

没有答案