How do I build a basic neural network in Java?

Asked: 2018-07-17 17:47:28

Tags: java machine-learning neural-network backpropagation

I am trying to build a basic neural network in Java that computes the logical XOR function.

The network has two input neurons, a hidden layer with three neurons, and one output neuron.

But after a few iterations, the output error becomes NaN.

I have looked at other neural-network implementations and tutorials, but I cannot find the mistake. I suspect the problem is in my backward function.

Please help me understand what I am doing wrong.

My code:

import org.ejml.simple.SimpleMatrix;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// SimpleMatrix constructor format: SimpleMatrix(rows, cols)
//The layers are represented as a matrix with 1 row and multiple columns (row vector)
public class Network {
    private SimpleMatrix inputs, outputs, hidden, W1, W2, predicted;
    static final double LEARNING_RATE = 0.3;

    Network(List<double[]> ips, List<double[]> ops){
        hidden = new SimpleMatrix(1, 3);
        W1 = new SimpleMatrix(ips.get(0).length, hidden.numCols());
        W2 = new SimpleMatrix(hidden.numCols(), ops.get(0).length);
        initWeights(W1,W2);

        for(int i=0;i<5000;i++){
            for(int j=0;j<ips.size();j++){
                train(ips.get(j), ops.get(j));
            }
        }
        System.out.println("Trained");
    }

    //Prints output matrix
    SimpleMatrix predict(double[] ip){
        SimpleMatrix bkpInputs = inputs.copy();
        SimpleMatrix bkpOutputs = outputs.copy();

        inputs = new SimpleMatrix(1, ip.length);
        inputs.setRow(0, 0, ip);

        forward();
        inputs = bkpInputs;
        outputs = bkpOutputs;

        predicted.print();
        return predicted;
    }

    void train(double[] inputs, double[] outputs){
        this.inputs = new SimpleMatrix(1, inputs.length);
        this.inputs.setRow(0, 0, inputs);
        this.outputs = new SimpleMatrix(1, outputs.length);
        this.outputs.setRow(0,0,outputs);
        this.predicted = new SimpleMatrix(1,outputs.length);

        forward();
        backward();
    }

    private void initWeights(SimpleMatrix... W){
        Random random = new Random();
        for (SimpleMatrix aW : W) {
            for (int i = 0; i < aW.numRows(); i++)
                for (int j = 0; j < aW.numCols(); j++)
                    aW.set(i, j, random.nextDouble());
        }
    }

    //Using logistic function
    double sigmoid(double x){
        return (1/(1+Math.exp(-x)));
    }

    double sigmoidPrime(double x){
        return sigmoid(x)/(1-sigmoid(x));
    }

    void forward(){
        hidden = inputs.mult(W1);
        for(int i=0;i<hidden.numCols();i++){
            double x = sigmoid(hidden.get(0,i));
            hidden.set(0,i,x);
        }
        predicted = hidden.mult(W2);
        for(int i=0;i<predicted.numRows();i++){
            for(int j=0;j<predicted.numCols();j++){
                predicted.set(i,j, sigmoid(predicted.get(i,j)));
            }
        }
    }

    void backward(){

        //Error in output
        double o_error = 0.0;
        //Error functions I tried: (1/2)( (predicted-actual) ^ 2) and (predicted - actual)
        for(int i=0;i<outputs.numCols();i++)
            o_error += (predicted.get(0, i)-outputs.get(0, i));//Math.pow(predicted.get(0, i)-outputs.get(0, i), 2)/2;
        //Checking output error
        System.out.println(o_error);

        //Output deltas
        SimpleMatrix o_deltas = new SimpleMatrix(1, outputs.numCols());
        for(int i=0;i<outputs.numCols();i++)
            o_deltas.set(0, i, o_error*sigmoidPrime(predicted.get(0, i))); 


        //Error in hidden layer and deltas
        double h_error = o_deltas.dot(W2.transpose());
        SimpleMatrix h_deltas = new SimpleMatrix(1, hidden.numCols());
        for(int i=0;i<hidden.numCols();i++)
            h_deltas.set(0, i, h_error*sigmoidPrime(hidden.get(0, i)));


        //Hidden->Output layer update
        SimpleMatrix W2_delta = W2.mult(o_deltas.transpose());
        for(int i=0;i<W2.numRows();i++){
            for(int j=0;j<W2.numCols();j++){
                W2.set(i,j, W2.get(i,j) + LEARNING_RATE*W2_delta.get(i, 0));
            }
        }

        //Input->Hidden layer update
        SimpleMatrix W1_delta = W1.mult(h_deltas.transpose());
        for(int i=0;i<W1.numRows();i++){
            for(int j=0;j<W1.numCols();j++){
                W1.set(i,j, W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));
            }
        }
    }


    public static void main(String[] args){
        double[][] ips = {
                {0,0},
                {0,1},
                {1,0},
                {1,1}
        };

        double[][] ops = {
                {0},
                {1},
                {1},
                {0}
        };

        List<double[]> ip = new ArrayList<>();
        List<double[]> op = new ArrayList<>();

        for(int i=0;i<ips.length;i++){
            ip.add(ips[i]);
            op.add(ops[i]);
        }

        double[] testip = {1,0};
        Network n = new Network(ip,op);
        n.predict(testip);
    }
}

2 Answers:

Answer 0 (score: 1):

This is probably not what is causing your problem, but I noticed

W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));

when you update the weights. I think the correct formula includes the output of the connected node, so your code should be:

W1(i,j) += LEARNING_RATE * W1_delta.get(i, 0) *  <output from the connected node>;

It may not fix it, but it's worth a try!
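As a standalone sketch of this update rule (the class and method names here are hypothetical, not part of the asker's code): the step for each weight `w[i][j]` is the learning rate times the target node's delta times the activation of the source node feeding that weight, so the delta alone is not enough.

```java
// Hypothetical sketch of the suggested update rule: the step for w[i][j]
// is lr * delta[j] * a[i], i.e. scaled by the output of the connected node.
public class WeightUpdateSketch {
    static void updateWeights(double[][] w, double[] a, double[] delta, double lr) {
        for (int i = 0; i < w.length; i++) {          // i: source node
            for (int j = 0; j < w[i].length; j++) {   // j: target node
                w[i][j] += lr * delta[j] * a[i];
            }
        }
    }

    public static void main(String[] args) {
        double[][] w = {{0.5, 0.5}, {0.5, 0.5}};
        double[] a = {1.0, 0.0};        // source-layer activations
        double[] delta = {0.2, -0.1};   // target-layer deltas
        updateWeights(w, a, delta, 0.3);
        // Only the weights fed by the active node (a[0] = 1) move.
        System.out.println(w[0][0] + " " + w[0][1] + " " + w[1][0] + " " + w[1][1]);
    }
}
```

One side effect of this rule: weights fed by a zero activation do not move at all, which is why an XOR input like {0,0} would leave W1 untouched without a bias term.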

Answer 1 (score: 0):

Try lowering the learning rate. When the error is NaN, it usually means your cost/error function has exploded. Try something in the range [10^-3, 10^-5].
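That failure mode can be reproduced with a toy gradient descent, independent of the asker's network (the code below is a hypothetical illustration): minimizing f(x) = x^2, whose gradient is 2x, any learning rate above 1 makes each update overshoot the minimum, so the iterate grows geometrically until the arithmetic overflows and the value degenerates to NaN.

```java
// Hypothetical illustration: gradient descent on f(x) = x^2 (gradient 2x).
// With a small learning rate the iterate shrinks toward the minimum at 0;
// with an oversized one each step overshoots, |x| grows geometrically, and
// after enough steps the floating-point arithmetic overflows into NaN.
public class LearningRateSketch {
    static double descend(double x, double lr, int steps) {
        for (int i = 0; i < steps; i++) {
            x -= lr * 2 * x;   // x_{k+1} = x_k - lr * f'(x_k)
        }
        return x;
    }

    public static void main(String[] args) {
        System.out.println(descend(1.0, 0.01, 5000)); // converges toward 0
        System.out.println(descend(1.0, 1.1, 5000));  // blows up to NaN
    }
}
```

The same principle applies here: if the error printed each iteration keeps growing instead of shrinking, the step size is too large for the loss surface.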