Java - Simple neural network implementation is not working

Asked: 2016-03-10 17:42:12

Tags: java neural-network

I am trying to implement a neural network with the following setup:

  1. 5 input nodes (+1 bias)
  2. a hidden layer with 1 hidden node (+1 bias)
  3. 1 output unit
  4. the training data I am using is the disjunction (logical OR) of the 5 input units; a sketch of its format is shown below

The overall error oscillates instead of decreasing and reaches very high values.

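Based on how readCSV works in the code below (it skips the first line and then reads 32 rows of 6 comma-separated integers, i.e. the 5 inputs plus the target), disjunction.csv looks roughly like this; the column names are only illustrative and only a few of the 32 data rows are shown:

    input1,input2,input3,input4,input5,target
    0,0,0,0,0,0
    0,0,0,0,1,1
    0,0,0,1,0,1
    1,1,1,1,1,1

The full implementation: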
    package neuralnetworks;
    
    import java.io.File;
    import java.io.FileNotFoundException;
    import java.math.*;
    import java.util.Random;
    import java.util.Scanner;
    
    public class NeuralNetworks {
        private double[] weightslayer1;
        private double[] weightslayer2;
        private int[][] training;
    
    
        public NeuralNetworks(int inputLayerSize, int weights1, int weights2) {
            weightslayer1 = new double[weights1];
            weightslayer2 = new double[weights2];
    
        }
    
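        // Reads disjunction.csv: the first line is skipped, and every following
        // line is parsed as 6 comma-separated integers (5 inputs and the target).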
        public static int[][] readCSV() {
            Scanner readfile = null;
            try {
                readfile = new Scanner(new File("disjunction.csv"));
            } catch (FileNotFoundException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            Scanner delimit;
            int[][] train = new int[32][6];
            int lines = 0;
            while (readfile.hasNext()) {
                String line = readfile.nextLine();
                delimit = new Scanner(line);
                delimit.useDelimiter(",");
                int features = 0;
                while (delimit.hasNext() && lines > 0) {
                    train[lines - 1][features] = Integer.parseInt(delimit.next());
                    features++;
                }
    
                lines++;
            }
            return train;
    
        }
    
        public double linearcomb(double[] input, double[] weights) { //calculates the sum of the multiplication of weights and inputs
            double sigma = 0;
            for (int i = 0; i < input.length; i++) {
                sigma += (input[i] * weights[i]);
    
            }
            return sigma;
        }
    
        public double hiddenLayerOutput(int[] inputs) { //calculates the output of the hiddenlayer
    
            double[] formattedInput = new double[6]; //adds the bias unit
            formattedInput[0] = 1;
            for (int i = 1; i < formattedInput.length; i++)
                formattedInput[i] = inputs[i - 1];
            double hlOutput = linearcomb(formattedInput, weightslayer1);
            return hlOutput;
    
        }
    
        public double feedForward(int[] inputs) { //calculates the output
    
            double hlOutput = hiddenLayerOutput(inputs);
            double[] olInput = new double[2];
            olInput[0] = 1;
            olInput[1] = hlOutput;
    
            double output = linearcomb(olInput, weightslayer2);
            return output;
        }
    
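        // Computes the output and hidden deltas using the sigmoid-derivative form
        // o * (1 - o) * error, then delegates the weight changes to updateweights.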
        public void backprop(double predoutput, double targetout, double hidout, double learningrate, int[] input) {
    
            double outputdelta = predoutput * (1 - predoutput) * (targetout - predoutput);
            double hiddendelta = hidout * (1 - hidout) * (outputdelta * weightslayer2[1]);
    
            updateweights(learningrate, outputdelta, hiddendelta, input);
    
        }
    
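        // Adds learning-rate-scaled deltas to every weight in both layers.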
        public void updateweights(double learningrate, double outputdelta, double hiddendelta, int[] input) {
            for (int i = 0; i < weightslayer1.length; i++) {
                double deltaw1 = learningrate * hiddendelta * input[i];
                weightslayer1[i] += deltaw1;
    
            }
            for (int i = 0; i < weightslayer2.length; i++) {
                double deltaw2 = learningrate * outputdelta * hiddenLayerOutput(input);
                weightslayer2[i] += deltaw2;
            }
    
        }
    
        public double test(int[] inputs) {
    
            return feedForward(inputs);
        }
    
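        // Runs repeated passes over the whole training set until the summed
        // squared error of a pass drops below 1.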
        public void train() {
            double learningrate = 0.01;
            double output;
            double hiddenoutput;
    
            double error = 100;
            do {
                error = 0;
                for (int i = 0; i < training.length; i++) {
                    output = feedForward(training[i]);
                    error += (training[i][5] - output) * (training[i][5] - output) / 2;
                    hiddenoutput = hiddenLayerOutput(training[i]);
    
                    backprop(output, training[i][5], hiddenoutput, learningrate, training[i]);
                }
                //System.out.println(error);
    
            }while(error>1);
    
        }
    
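        // Builds the network, initializes all weights uniformly in [-0.5, 0.5),
        // loads the CSV data, trains, and prints the output for the all-zeros input.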
        public static void main(String[] args) {
            NeuralNetworks nn = new NeuralNetworks(6, 6, 2);
            Random rand = new Random();
    
            nn.weightslayer2[0] = (rand.nextDouble() - 0.5);
            nn.weightslayer2[1] = (rand.nextDouble() - 0.5);
    
            for (int i = 0; i < nn.weightslayer1.length; i++)
                nn.weightslayer1[i] = (rand.nextDouble() - 0.5);
    
            nn.training = readCSV();
            /*for (int i = 0; i < nn.training.length; i++) {
                for (int j = 0; j < nn.training[i].length; j++)
                    System.out.print(nn.training[i][j] + ",");
                System.out.println();
    
            }*/
            nn.train();
    
            int[] testa = { 0, 0, 0, 0, 0 };
            System.out.println(nn.test(testa));
    
        }
    }
    

0 Answers:

No answers yet.