我正在尝试构建一个基本的神经网络来计算Java中的逻辑XOR
函数。
该网络具有两个输入神经元,一个包含三个神经元的隐藏层和一个输出神经元。
但是经过几次迭代,输出错误变为NaN
。
我已经看过其他实现神经网络的实现和教程,但是找不到错误。我觉得问题出在我的向后功能上。
请帮助我了解我出了什么问题。
我的代码:
import org.ejml.simple.SimpleMatrix;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
// SimpleMatrix constructor format: SimpleMatrix(rows, cols)
//The layers are represented as a matrix with 1 row and multiple columns (row vector)
public class Network {
private SimpleMatrix inputs, outputs, hidden, W1, W2, predicted;
static final double LEARNING_RATE = 0.3;
Network(List<double[]> ips, List<double[]> ops){
hidden = new SimpleMatrix(1, 3);
W1 = new SimpleMatrix(ips.get(0).length, hidden.numCols());
W2 = new SimpleMatrix(hidden.numCols(), ops.get(0).length);
initWeights(W1,W2);
for(int i=0;i<5000;i++){
for(int j=0;j<ips.size();j++){
train(ips.get(j), ops.get(j));
}
}
System.out.println("Trained");
}
//Prints output matrix
SimpleMatrix predict(double[] ip){
SimpleMatrix bkpInputs = inputs.copy();
SimpleMatrix bkpOutputs = outputs.copy();
inputs = new SimpleMatrix(1, ip.length);
inputs.setRow(0, 0, ip);
forward();
inputs = bkpInputs;
outputs = bkpOutputs;
predicted.print();
return predicted;
}
void train(double[] inputs, double[] outputs){
this.inputs = new SimpleMatrix(1, inputs.length);
this.inputs.setRow(0, 0, inputs);
this.outputs = new SimpleMatrix(1, outputs.length);
this.outputs.setRow(0,0,outputs);
this.predicted = new SimpleMatrix(1,outputs.length);
forward();
backward();
}
private void initWeights(SimpleMatrix... W){
Random random = new Random();
for (SimpleMatrix aW : W) {
for (int i = 0; i < aW.numRows(); i++)
for (int j = 0; j < aW.numCols(); j++)
aW.set(i, j, random.nextDouble());
}
}
//Using logistic function
double sigmoid(double x){
return (1/(1+Math.exp(-x)));
}
double sigmoidPrime(double x){
return sigmoid(x)/(1-sigmoid(x));
}
void forward(){
hidden = inputs.mult(W1);
for(int i=0;i<hidden.numCols();i++){
double x = sigmoid(hidden.get(0,i));
hidden.set(0,i,x);
}
predicted = hidden.mult(W2);
for(int i=0;i<predicted.numRows();i++){
for(int j=0;j<predicted.numCols();j++){
predicted.set(i,j, sigmoid(predicted.get(i,j)));
}
}
}
void backward(){
//Error in output
double o_error = 0.0;
//Error functions I tried: (1/2)( (predicted-actual) ^ 2) and (predicted - actual)
for(int i=0;i<outputs.numCols();i++)
o_error += (predicted.get(0, i)-outputs.get(0, i));//Math.pow(predicted.get(0, i)-outputs.get(0, i), 2)/2;
//Checking output error
System.out.println(o_error);
//Output deltas
SimpleMatrix o_deltas = new SimpleMatrix(1, outputs.numCols());
for(int i=0;i<outputs.numCols();i++)
o_deltas.set(0, i, o_error*sigmoidPrime(predicted.get(0, i)));
//Error in hidden layer and deltas
double h_error = o_deltas.dot(W2.transpose());
SimpleMatrix h_deltas = new SimpleMatrix(1, hidden.numCols());
for(int i=0;i<hidden.numCols();i++)
h_deltas.set(0, i, h_error*sigmoidPrime(hidden.get(0, i)));
//Hidden->Output layer update
SimpleMatrix W2_delta = W2.mult(o_deltas.transpose());
for(int i=0;i<W2.numRows();i++){
for(int j=0;j<W2.numCols();j++){
W2.set(i,j, W2.get(i,j) + LEARNING_RATE*W2_delta.get(i, 0));
}
}
//Input->Hidden layer update
SimpleMatrix W1_delta = W1.mult(h_deltas.transpose());
for(int i=0;i<W1.numRows();i++){
for(int j=0;j<W1.numCols();j++){
W1.set(i,j, W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));
}
}
}
public static void main(String[] args){
double[][] ips = {
{0,0},
{0,1},
{1,0},
{1,1}
};
double[][] ops = {
{0},
{1},
{1},
{0}
};
List<double[]> ip = new ArrayList<>();
List<double[]> op = new ArrayList<>();
for(int i=0;i<ips.length;i++){
ip.add(ips[i]);
op.add(ops[i]);
}
double[] testip = {1,0};
Network n = new Network(ip,op);
n.predict(testip);
}
}