Problem with a neural network implementation in Java

Date: 2017-07-06 23:36:33

Tags: java neural-network

I have been trying to implement a basic neural network in Java, but I am having trouble getting it to train: the output comes out close to 0 on every test. After struggling with it for a while, my assumption is that the backpropagation is incorrect (who could have guessed?).
I have been going through the many similar questions on this site, as well as several other example implementations, and I feel stupid for not being able to figure it out, but I cannot find anything that helps.

My implementation of backpropagation is based on http://www.nnwj.de/backpropagation.html
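
For reference, my understanding of the update rule described there (the standard delta rule for a sigmoid unit) boils down to the following. This is only a toy calculation with made-up numbers to show the formulas, not code taken from that page:

public class DeltaRuleExample {

    public static void main(String[] args) {
        // toy values: one hidden activation h, one output o, target t
        float h = 0.6f, o = 0.8f, t = 1f;
        float wHiddenToOut = 0.3f, wInToHidden = -0.4f, x = 1f; // x = an input value
        float lr = 0.25f;

        // output-layer delta: (t - o) * o * (1 - o)  -- sigmoid derivative taken on the output
        float deltaOut = (t - o) * o * (1 - o);                       // ~0.032

        // hidden-layer delta: h * (1 - h) * deltaOut * w  -- uses the OLD hidden-to-output weight
        float deltaHidden = h * (1 - h) * deltaOut * wHiddenToOut;    // ~0.0023

        // weight updates: w += learningRate * delta * (activation feeding into that weight)
        float newWHiddenToOut = wHiddenToOut + lr * deltaOut * h;     // ~0.3048
        float newWInToHidden  = wInToHidden  + lr * deltaHidden * x;  // ~-0.3994

        System.out.println(deltaOut + " " + deltaHidden + " " + newWHiddenToOut + " " + newWInToHidden);
    }
}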

Here is what I have so far (apologies for any bad conventions and for the lack of separate methods for the individual steps; I just wanted to get it testing before making it more usable):

NeuralNet.java

import java.util.Random;
import java.util.Scanner;

public class NeuralNet {

    public Node[][] nodes;

    public static void main(String[] args) {
        new NeuralNet();
    }

    public NeuralNet() {
        float[][][] trainingData = { // XOR
                {{0, 0}, {0}},
                {{0, 1}, {1}},
                {{1, 0}, {1}},
                {{1, 1}, {0}}
        };

        nodes = new Node[3][];
        nodes[0] = new Node[trainingData[0][0].length+1]; // input
        nodes[1] = new Node[2+1];                         // hidden
        nodes[2] = new Node[trainingData[0][1].length];   // output

        // create nodes

        // input
        for(int i = 0; i < nodes[0].length-1; i++){
            nodes[0][i] = new InputNode(1f);
        }
        nodes[0][nodes[0].length-1] = new InputNode(-1); // create constant bias node on input layer

        // hidden
        SigmoidFunction sf = new SigmoidFunction();
        for(int i = 0; i < nodes[1].length-1; i++){
            nodes[1][i] = new Node(sf);
        }
        nodes[1][nodes[1].length-1] = new InputNode(-1); // create constant bias node on hidden layer

        // output
        for(int i = 0; i < nodes[2].length; i++){
            nodes[2][i] = new Node(sf);
        }

        // create synapses

        Random r = new Random();

        for(int layer = 0; layer < nodes.length-1; layer++){
            int nextLayer = layer+1;
            for(int i = 0; i < nodes[layer].length; i++){
                for(int i2 = 0; i2 < nodes[nextLayer].length; i2++){
                    Node in = nodes[layer][i];
                    Node out = nodes[nextLayer][i2];

                    new Synapse(in, out, r.nextFloat() * (r.nextBoolean() ? 1 : -1));
                }
            }
        }

        System.out.println("Number of times to train:");
        Scanner sc = new Scanner(System.in);
        int iterationCount = sc.nextInt();
        sc.close();

        float learningRate = .25f;

        for(int iter = 0; iter < iterationCount; iter++){ // train a certain number of times

            float[][] thisSet = trainingData[iter%trainingData.length]; // go through the training sets in order

            System.out.println("==========================");
            System.out.println("inputs: ");

            // set the values of the input nodes

            for(int i = 0; i < nodes[0].length-1; i++){
                System.out.println(thisSet[0][i]);
                ((InputNode)nodes[0][i]).setValue(thisSet[0][i]);
            }

            Node outNode = nodes[2][0]; // the output node

            float result = outNode.getOutput(); // do forward propagation
            System.out.println("result = " + result);

            float target = thisSet[1][0];
            System.out.println("expected = " + target);

            float error = target - result;
            System.out.println("error = " + error);

            System.out.println();

            // do backpropagation

            for(Node n : nodes[1]){

                // calculate new weights between hidden and output

                float oldW1 = n.getOutputSynapse().getWeight();

                float w1Result = n.getOutput();

                float deltaW1 = learningRate * error * w1Result * -error * (1 + error);
                float newW1 = oldW1 + deltaW1;

                System.out.println("(hidden to outer synapse) " + oldW1 + " + " + deltaW1 + " -> " + newW1);

                n.getOutputSynapse().setWeight(newW1);

                // calculate new weights between input and hidden

                for(int i = 0; i < n.getInputSynapses().size(); i++){
                    Synapse s = n.getInputSynapses().get(i);

                    float oldW = s.getWeight();

                    float deltaW = learningRate * error * s.getInNode().getOutput() * w1Result * (1 - w1Result);

                    float newW = oldW + deltaW;

                    System.out.println("(inner to hidden synapse) " + oldW + " + " + deltaW + " -> " + newW);
                    s.setWeight(newW);
                }
            }
        }

        System.out.println();
        System.out.println("Finished!");
        System.out.println("Results:");

        for(float[][] thisSet : trainingData){
            System.out.println();

            System.out.println("inputs : ");

            for(int i = 0; i < nodes[0].length-1; i++){
                System.out.println(thisSet[0][i]);
                ((InputNode)nodes[0][i]).setValue(thisSet[0][i]);
            }

            Node outNode = nodes[2][0];

            float calculated = outNode.getOutput();

            System.out.println("output = " + calculated);

            float target = thisSet[1][0];

            System.out.println("expected = " + target);
        }

    }

}
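
Just a cleanup thought, unrelated to the training problem: the input-setting plus forward-pass code is duplicated between the training loop and the results loop. A small helper along these lines inside NeuralNet would remove the duplication (same behaviour, just a sketch):

    private float evaluate(float[] inputs) {
        for (int i = 0; i < nodes[0].length - 1; i++) { // skip the constant bias node
            ((InputNode) nodes[0][i]).setValue(inputs[i]);
        }
        return nodes[2][0].getOutput(); // forward propagation from the single output node
    }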

Node.java

import java.util.ArrayList;
import java.util.List;

public class Node {

    private Synapse outNode;
    private List<Synapse> inNodes = new ArrayList<Synapse>();
    private Function activationFunction;

    public Node(Function activationFunction) {
        this.activationFunction = activationFunction;
    }

    public float sumInputs(){
        float sum = 0f;
        for(Synapse s : inNodes) sum += s.getInNode().getOutput() * s.getWeight();
        return sum;
    }

    [...] getters & setters

}
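
For completeness, since the getters and setters are elided: getOutput() essentially just applies the activation function to sumInputs(), and InputNode (not listed here) overrides it to return the value it holds. A minimal sketch of that assumption:

    // in Node -- sketch of the elided method
    public float getOutput() {
        return activationFunction.applyFor(sumInputs());
    }

InputNode.java (sketch)

public class InputNode extends Node {

    private float value;

    public InputNode(float value) {
        super(null); // no activation function needed
        this.value = value;
    }

    public void setValue(float value) {
        this.value = value;
    }

    @Override
    public float getOutput() {
        return value; // just return the stored value; incoming synapses are ignored
    }
}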

Synapse.java

public class Synapse {

    private Node nodeIn;
    private Node nodeOut;

    private float weight = 0f;

    public Synapse(Node in, Node out, float initialWeight) {
        in.setOutputNode(this);
        out.addInputNode(this);

        this.nodeIn = in;
        this.nodeOut = out;
        weight = initialWeight;
    }

    [...] getters & setters

}

SigmoidFunction.java

public class SigmoidFunction extends Function{

    @Override
    public float applyFor(float input) {
        return (float) (1f / (1f + Math.pow(Math.E, -input)));
    }

    @Override
    public float applyForPrime(float input){
        return applyFor(input) * (1f - applyFor(input));
    }

}
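
As a quick sanity check on the derivative formula (standalone, with its own copy of the sigmoid rather than the class above), applyForPrime can be compared against a central-difference approximation; the two should agree to roughly four decimal places:

public class SigmoidCheck {

    static float sigmoid(float x) {
        return (float) (1f / (1f + Math.exp(-x)));
    }

    public static void main(String[] args) {
        float x = 0.37f;  // arbitrary test point
        float eps = 1e-3f;

        float analytic = sigmoid(x) * (1f - sigmoid(x));                    // same formula as applyForPrime
        float numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2f * eps); // central difference

        System.out.println(analytic + " vs " + numeric);
    }
}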

Here is a diagram of what the network is supposed to look like: [image]
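
For comparison, here is a minimal, self-contained plain-array version of that same 2-2-1 layout and the update rule as I understand it (bias modelled as a constant -1 input, as above). It is a sketch of what I am trying to reproduce with the Node/Synapse classes, not a drop-in fix, and with only two hidden units it can occasionally get stuck in a local minimum:

import java.util.Random;

public class XorReference {

    static float sigmoid(float x) {
        return (float) (1f / (1f + Math.exp(-x)));
    }

    public static void main(String[] args) {
        float[][][] data = { // XOR
                {{0, 0}, {0}},
                {{0, 1}, {1}},
                {{1, 0}, {1}},
                {{1, 1}, {0}}
        };

        Random r = new Random();
        float[][] wIH = new float[3][2]; // input (2 + bias) -> hidden (2)
        float[] wHO = new float[3];      // hidden (2 + bias) -> output (1)
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 2; j++) wIH[i][j] = r.nextFloat() * 2 - 1;
            wHO[i] = r.nextFloat() * 2 - 1;
        }

        float lr = 0.5f;

        for (int iter = 0; iter < 100000; iter++) {
            float[] in = data[iter % 4][0];
            float target = data[iter % 4][1][0];

            // forward pass
            float[] inAct = {in[0], in[1], -1f}; // last entry = bias node
            float[] hidAct = new float[3];
            for (int j = 0; j < 2; j++) {
                float sum = 0;
                for (int i = 0; i < 3; i++) sum += inAct[i] * wIH[i][j];
                hidAct[j] = sigmoid(sum);
            }
            hidAct[2] = -1f; // hidden-layer bias node
            float outSum = 0;
            for (int j = 0; j < 3; j++) outSum += hidAct[j] * wHO[j];
            float out = sigmoid(outSum);

            // backward pass: compute both deltas with the OLD weights, then update
            float deltaOut = (target - out) * out * (1 - out);
            float[] deltaHid = new float[2];
            for (int j = 0; j < 2; j++) {
                deltaHid[j] = hidAct[j] * (1 - hidAct[j]) * deltaOut * wHO[j];
            }
            for (int j = 0; j < 3; j++) wHO[j] += lr * deltaOut * hidAct[j];
            for (int i = 0; i < 3; i++) {
                for (int j = 0; j < 2; j++) wIH[i][j] += lr * deltaHid[j] * inAct[i];
            }
        }

        // print the results for all four patterns
        for (float[][] set : data) {
            float[] inAct = {set[0][0], set[0][1], -1f};
            float[] hidAct = new float[3];
            for (int j = 0; j < 2; j++) {
                float sum = 0;
                for (int i = 0; i < 3; i++) sum += inAct[i] * wIH[i][j];
                hidAct[j] = sigmoid(sum);
            }
            hidAct[2] = -1f;
            float outSum = 0;
            for (int j = 0; j < 3; j++) outSum += hidAct[j] * wHO[j];
            System.out.println(set[0][0] + ", " + set[0][1] + " -> " + sigmoid(outSum)
                    + " (expected " + set[1][0] + ")");
        }
    }
}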

(Sorry if that is a lot of code to dump; I did not want to leave out anything important.)
(If I am going about this completely wrong, please tell me.) (And if I am making any Stack Overflow mistakes, let me know; this is my first post.)

0 Answers:

No answers