I've been trying to implement a basic neural network in Java, but I'm having trouble training it: every test output ends up close to 0.
After struggling with it for a while, I assume the backpropagation is incorrect (who could have guessed?).
I've browsed plenty of similar questions on this site and looked at other example implementations, and I feel stupid for not being able to figure it out, but I can't find anything that helps.
My backpropagation implementation is based on http://www.nnwj.de/backpropagation.html
Here's what I have so far (sorry for any bad conventions and the lack of separate methods for the individual steps; I just wanted to get it testable before making it nicer):
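For reference, this is my understanding of the update rule described there, restated as a minimal standalone sketch (the names `outputDelta` and `hiddenToOutputDeltaW` are mine, not from the tutorial or from my code below):

```java
public class BackpropSketch {

    // delta for a sigmoid output node: (target - out) * out * (1 - out)
    static double outputDelta(double target, double out) {
        return (target - out) * out * (1.0 - out);
    }

    // weight change for a hidden->output synapse: learningRate * delta * hiddenOut
    static double hiddenToOutputDeltaW(double learningRate, double delta, double hiddenOut) {
        return learningRate * delta * hiddenOut;
    }

    public static void main(String[] args) {
        double delta = outputDelta(1.0, 0.5); // 0.125
        System.out.println(delta);
        System.out.println(hiddenToOutputDeltaW(0.25, delta, 0.5)); // 0.015625
    }
}
```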
NeuralNet.java
import java.util.Random;
import java.util.Scanner;
public class NeuralNet {

    public Node[][] nodes;

    public static void main(String[] args) {
        new NeuralNet();
    }

    public NeuralNet() {
        float[][][] trainingData = { // XOR
                {{0, 0}, {0}},
                {{0, 1}, {1}},
                {{1, 0}, {1}},
                {{1, 1}, {0}}
        };

        nodes = new Node[3][];
        nodes[0] = new Node[trainingData[0][0].length + 1]; // input
        nodes[1] = new Node[2 + 1]; // hidden
        nodes[2] = new Node[trainingData[0][1].length]; // output

        // create nodes
        // input
        for (int i = 0; i < nodes[0].length - 1; i++) {
            nodes[0][i] = new InputNode(1f);
        }
        nodes[0][nodes[0].length - 1] = new InputNode(-1); // create constant bias node on input layer
        // hidden
        SigmoidFunction sf = new SigmoidFunction();
        for (int i = 0; i < nodes[1].length - 1; i++) {
            nodes[1][i] = new Node(sf);
        }
        nodes[1][nodes[1].length - 1] = new InputNode(-1); // create constant bias node on hidden layer
        // output
        for (int i = 0; i < nodes[2].length; i++) {
            nodes[2][i] = new Node(sf);
        }

        // create synapses
        Random r = new Random();
        for (int layer = 0; layer < nodes.length - 1; layer++) {
            int nextLayer = layer + 1;
            for (int i = 0; i < nodes[layer].length; i++) {
                for (int i2 = 0; i2 < nodes[nextLayer].length; i2++) {
                    Node in = nodes[layer][i];
                    Node out = nodes[nextLayer][i2];
                    new Synapse(in, out, r.nextFloat() * (r.nextBoolean() ? 1 : -1));
                }
            }
        }

        System.out.println("Number of times to train:");
        Scanner sc = new Scanner(System.in);
        int iterationCount = sc.nextInt();
        sc.close();

        float learningRate = .25f;
        for (int iter = 0; iter < iterationCount; iter++) { // train a certain number of times
            float[][] thisSet = trainingData[iter % trainingData.length]; // go through the training sets in order
            System.out.println("==========================");
            System.out.println("inputs: ");
            // set the values of the input nodes
            for (int i = 0; i < nodes[0].length - 1; i++) {
                System.out.println(thisSet[0][i]);
                ((InputNode) nodes[0][i]).setValue(thisSet[0][i]);
            }
            Node outNode = nodes[2][0]; // the output node
            float result = outNode.getOutput(); // do forward propagation
            System.out.println("result = " + result);
            float target = thisSet[1][0];
            System.out.println("expected = " + target);
            float error = target - result;
            System.out.println("error = " + error);
            System.out.println();

            // do backpropagation
            for (Node n : nodes[1]) {
                // calculate new weights between hidden and output
                float oldW1 = n.getOutputSynapse().getWeight();
                float w1Result = n.getOutput();
                float deltaW1 = learningRate * error * w1Result * -error * (1 + error);
                float newW1 = oldW1 + deltaW1;
                System.out.println("(hidden to outer synapse) " + oldW1 + " + " + deltaW1 + " -> " + newW1);
                n.getOutputSynapse().setWeight(newW1);
                // calculate new weights between input and hidden
                for (int i = 0; i < n.getInputSynapses().size(); i++) {
                    Synapse s = n.getInputSynapses().get(i);
                    float oldW = s.getWeight();
                    float deltaW = learningRate * error * s.getInNode().getOutput() * w1Result * (1 - w1Result);
                    float newW = oldW + deltaW;
                    System.out.println("(inner to hidden synapse) " + oldW + " + " + deltaW + " -> " + newW);
                    s.setWeight(newW);
                }
            }
        }

        System.out.println();
        System.out.println("Finished!");
        System.out.println("Results:");
        for (float[][] thisSet : trainingData) {
            System.out.println();
            System.out.println("inputs : ");
            for (int i = 0; i < nodes[0].length - 1; i++) {
                System.out.println(thisSet[0][i]);
                ((InputNode) nodes[0][i]).setValue(thisSet[0][i]);
            }
            Node outNode = nodes[2][0];
            float calculated = outNode.getOutput();
            System.out.println("output = " + calculated);
            float target = thisSet[1][0];
            System.out.println("expected = " + target);
        }
    }
}
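The line I'm most suspicious of is the hidden-to-output update. Pulled out on its own (the same expression as in the training loop above, just with plain hypothetical variables), it behaves oddly:

```java
public class SuspectUpdate {

    // the exact expression from my training loop above
    static double deltaW1(double learningRate, double error, double hiddenOut) {
        return learningRate * error * hiddenOut * -error * (1 + error);
    }

    public static void main(String[] args) {
        // error * -error = -(error * error) is never positive, so for a
        // positive hiddenOut the step has the same sign either way
        System.out.println(deltaW1(0.25, 0.1, 0.5));  // negative
        System.out.println(deltaW1(0.25, -0.1, 0.5)); // also negative
    }
}
```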
Node.java
import java.util.ArrayList;
import java.util.List;
public class Node {

    private Synapse outNode;
    private List<Synapse> inNodes = new ArrayList<Synapse>();
    private Function activationFunction;

    public Node(Function activationFunction) {
        this.activationFunction = activationFunction;
    }

    public float sumInputs() {
        float sum = 0f;
        for (Synapse s : inNodes) sum += s.getInNode().getOutput() * s.getWeight();
        return sum;
    }

    [...] getters & setters
}
Synapse.java
public class Synapse {

    private Node nodeIn;
    private Node nodeOut;
    private float weight = 0f;

    public Synapse(Node in, Node out, float initialWeight) {
        in.setOutputNode(this);
        out.addInputNode(this);
        this.nodeIn = in;
        this.nodeOut = out;
        weight = initialWeight;
    }

    [...] getters & setters
}
SigmoidFunction.java
public class SigmoidFunction extends Function {

    @Override
    public float applyFor(float input) {
        return (float) (1f / (1f + Math.pow(Math.E, -input)));
    }

    @Override
    public float applyForPrime(float input) {
        return applyFor(input) * (1f - applyFor(input));
    }
}
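I don't think the activation function itself is the problem; copying the two method bodies verbatim into a standalone check (dropping the `Function` base class, which is just those two abstract methods) gives the values I'd expect:

```java
public class SigmoidCheck {

    // same bodies as in SigmoidFunction above, without the base class
    static float applyFor(float input) {
        return (float) (1f / (1f + Math.pow(Math.E, -input)));
    }

    static float applyForPrime(float input) {
        return applyFor(input) * (1f - applyFor(input));
    }

    public static void main(String[] args) {
        System.out.println(applyFor(0f));      // 0.5
        System.out.println(applyForPrime(0f)); // 0.25
    }
}
```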
Here's a diagram of what the net is supposed to look like: image
(Sorry if this is a lot of code to dump; I didn't want to leave out anything important.)
(If I'm going about this completely wrong, let me know.)
(And if I'm doing Stack Overflow wrong, let me know too; this is my first post.)