I am trying to build a basic neural network using the book "Make Your Own Neural Network" by Tariq Rashid
and the Coding Train video playlist:
https://www.youtube.com/watch?v=XJ7HLz9VYz0&list=PLRqwX-V7Uu6aCibgK1PTWWu9by6XFdCfh
with the nn.js class from the Coding Train GitHub repository as a reference:
https://github.com/shiffman/Neural-Network-p5/blob/master/nn.js
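I am writing the network in Java. After finishing a single perceptron, I tried to train the network on XOR, as in the playlist. For some reason it does not work, even though my code is similar to the code in the book and in the video (except that the video uses JS). I train the network about 500,000 times on a random dataset of XOR inputs (4 inputs in total: [1,0] [0,1] [0,0] [1,1]).
When I ask for guesses after training, the results for all 4 inputs are close to 0.5 instead of 1, 1, 0, 0 (the test inputs are in the order [1,0] [0,1] [0,0] [1,1]).
This is my training function: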
public void train(double[] inputs, double[] target) {
    //generates the Hidden layer values
    this.input = Matrix.fromArrayToMatrix(inputs);
    feedForward(inputs);
    //convert to matrices
    Matrix targets = Matrix.fromArrayToMatrix(target);
    //calculate the output error
    Matrix outputErrors = Matrix.subtract(targets, output);
    //calculate the Gradient
    Matrix outputGradient = Matrix.map(output, NeuralNetwork::sigmoidDerivative);
    outputGradient = Matrix.matrixMultiplication(outputGradient, outputErrors);
    outputGradient.multiply(this.learningRate);
    //adjust the output layer bias
    this.bias_Output.add(outputGradient);
    //calculate the hidden layer weights delta
    Matrix hiddenT = Matrix.Transpose(hidden);
    Matrix hiddenToOutputDelta = Matrix.matrixMultiplication(outputGradient, hiddenT);
    //adjust the hidden layer weights
    this.weightsHiddenToOutput.add(hiddenToOutputDelta);
    //calculate the hidden layer error
    Matrix weightsHiddenToOutputT = Matrix.Transpose(weightsHiddenToOutput);
    Matrix hiddenErrors = Matrix.matrixMultiplication(weightsHiddenToOutputT, outputErrors);
    //calculate the hidden gradient
    Matrix hiddenGradient = Matrix.map(this.hidden, NeuralNetwork::sigmoidDerivative);
    hiddenGradient = Matrix.matrixMultiplication(hiddenGradient, hiddenErrors);
    hiddenGradient.multiply(this.learningRate);
    //adjust the hidden layer bias
    this.bias_Hidden.add(hiddenGradient);
    //calculate the input layer weights delta
    Matrix inputT = Matrix.Transpose(this.input);
    Matrix inputToHiddenDelta = Matrix.matrixMultiplication(hiddenGradient, inputT);
    //adjust the hidden layer weights
    this.weightsInputToHidden.add(inputToHiddenDelta);
}
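One thing worth checking in the training step above: it uses two different kinds of product. Combining the sigmoid derivative with the error is an element-wise (Hadamard) product of two column vectors, while computing the weight delta is a true matrix product of a column vector with a transposed (row) vector. The post does not show how Matrix.matrixMultiplication is implemented, so which of the two it performs in each call is unknown here. The sketch below uses plain arrays instead of the Matrix class; the class and method names (ElementWiseVsMatrixProduct, hadamard, outer) are hypothetical and only illustrate the two operations.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of the code in the post.
class ElementWiseVsMatrixProduct {

    // Element-wise (Hadamard) product of two column vectors of equal length.
    // This is the shape of operation needed for "derivative times error".
    static double[] hadamard(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double[] result = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            result[i] = a[i] * b[i];
        }
        return result;
    }

    // Matrix product of a column vector (n x 1) with a row vector (1 x m),
    // giving an n x m matrix. This is the shape needed for the weight delta
    // "gradient times transposed layer values".
    static double[][] outer(double[] column, double[] row) {
        double[][] result = new double[column.length][row.length];
        for (int i = 0; i < column.length; i++) {
            for (int j = 0; j < row.length; j++) {
                result[i][j] = column[i] * row[j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[] derivative = {0.25, 0.1};   // made-up sigmoid derivative values
        double[] errors = {2.0, -1.0};       // made-up error values
        double[] gradient = hadamard(derivative, errors);
        double[][] delta = outer(gradient, new double[] {1.0, 2.0, 3.0});
        System.out.println(Arrays.toString(gradient));
        System.out.println(Arrays.deepToString(delta));
    }
}
```

For a hidden layer with 2 nodes, hiddenGradient and hiddenErrors are both 2 x 1, so a strict matrix product of the two is not defined; comparing what Matrix.matrixMultiplication returns for such inputs against hadamard would show which operation it actually performs.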
These are the sigmoid functions:
private static double sigmoid(double x) {
    return 1d / (1d + Math.exp(-x));
}
private static double sigmoidDerivative(double x) {
    return (x * (1d - x));
}
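I compute the derivative this way because the values have already been passed through the sigmoid function during the feedforward pass, so all I need to do is compute the derivative from the activated values.
This is my guess/feedforward function: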
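The shortcut relies on the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), so it is only valid when the argument is an already-activated value. A minimal sketch that checks this numerically is shown here; the class name SigmoidDerivativeCheck and the helper numericalDerivative are hypothetical and exist only for this check.

```java
// Hypothetical check class, not part of the code in the post.
class SigmoidDerivativeCheck {

    static double sigmoid(double x) {
        return 1d / (1d + Math.exp(-x));
    }

    // Expects y = sigmoid(x), i.e. a value that has already been activated.
    static double sigmoidDerivativeFromOutput(double y) {
        return y * (1d - y);
    }

    // Central finite difference of sigmoid at x, used only as a reference value.
    static double numericalDerivative(double x) {
        double h = 1e-5;
        return (sigmoid(x + h) - sigmoid(x - h)) / (2 * h);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-2.0, 0.0, 0.5, 3.0}) {
            double y = sigmoid(x);
            System.out.printf("x=%.1f analytic=%.8f numeric=%.8f%n",
                    x, sigmoidDerivativeFromOutput(y), numericalDerivative(x));
        }
    }
}
```

The two columns should agree to several decimal places. Passing a raw, non-activated value to the shortcut would give a different (and possibly negative) result, so it matters that hidden and output really hold sigmoid outputs when train reads them.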
public double[] feedForward(double[] inputs) {
    double[] guess;
    //generates the Hidden layer values
    input = Matrix.fromArrayToMatrix(inputs);
    hidden = Matrix.matrixMultiplication(weightsInputToHidden, input);
    hidden.add(bias_Hidden);
    //activation function
    hidden.map(NeuralNetwork::sigmoid);
    //Generates the output layer values
    output = Matrix.matrixMultiplication(weightsHiddenToOutput, hidden);
    output.add(bias_Output);
    //activation function
    output.map(NeuralNetwork::sigmoid);
    guess = Matrix.fromMatrixToArray(output);
    return guess;
}
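One way to test the forward pass separately from training is to feed it weights that are known to solve XOR. The sketch below does this with plain arrays for a 2-2-1 network; the weights are hand-picked for illustration (hidden node 1 behaves like OR, hidden node 2 like NAND, and the output node like AND), and the names XorByHand, layer, and predict are hypothetical.

```java
// Hypothetical stand-alone class; it does not use the Matrix class from the post.
class XorByHand {

    static double sigmoid(double x) {
        return 1d / (1d + Math.exp(-x));
    }

    // Computes sigmoid(W * in + b) for one layer; W has one row per output node.
    static double[] layer(double[][] w, double[] b, double[] in) {
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < in.length; j++) {
                sum += w[i][j] * in[j];
            }
            out[i] = sigmoid(sum);
        }
        return out;
    }

    // Forward pass of a 2-2-1 network with hand-picked (not trained) weights.
    static double predict(double x1, double x2) {
        double[][] weightsInputToHidden = {{20, 20}, {-20, -20}}; // OR-like, NAND-like
        double[] biasHidden = {-10, 30};
        double[][] weightsHiddenToOutput = {{20, 20}};             // AND-like
        double[] biasOutput = {-30};
        double[] hidden = layer(weightsInputToHidden, biasHidden, new double[] {x1, x2});
        return layer(weightsHiddenToOutput, biasOutput, hidden)[0];
    }

    public static void main(String[] args) {
        System.out.println(predict(1, 0)); // close to 1
        System.out.println(predict(0, 1)); // close to 1
        System.out.println(predict(0, 0)); // close to 0
        System.out.println(predict(1, 1)); // close to 0
    }
}
```

If the same weights and biases are loaded into the Matrix-based network and feedForward does not give similar values, the forward pass (for example, whether hidden.map modifies the matrix in place) is worth a closer look; if it does, the issue more likely lies in the training step.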
This is the main class where I give it the dataset:
NeuralNetwork nn = new NeuralNetwork(2,2,1);
double[] label0 = {0};
double[] label1 = {1};
Literal l1 = new Literal(label1,0,1);
Literal l2 = new Literal(label1,1,0);
Literal l3 = new Literal(label0,0,0);
Literal l4 = new Literal(label0,1,1);
Literal[] arr = {l1, l2, l3, l4};
Random random = new Random();
for (int i = 0; i < 500000; i++) {
    Literal l = arr[i%4];
    nn.train(l.getTruthValue(), l.getLabel());
}
System.out.println(Arrays.toString(nn.feedForward(l1.getTruthValue())));
System.out.println(Arrays.toString(nn.feedForward(l2.getTruthValue())));
System.out.println(Arrays.toString(nn.feedForward(l3.getTruthValue())));
System.out.println(Arrays.toString(nn.feedForward(l4.getTruthValue())));
But for some reason the output looks like this:
[0.47935468493879807]
[0.5041956026507048]
[0.4575246472403595]
[0.5217568912941623]
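I have already tried subtracting instead of adding each bias and weight update (since you need the negative gradient, although both the book and the video add rather than subtract), which means changing these 4 lines to subtract: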
this.bias_Output.subtract(outputGradient);
this.weightsHiddenToOutput.subtract(hiddenToOutputDelta);
this.bias_Hidden.subtract(hiddenGradient);
this.weightsInputToHidden.subtract(inputToHiddenDelta);
With that change, these are the two main outputs I get:
[0.9999779359460259]
[0.9999935716126019]
[0.9999860145346924]
[0.999990155468117]
or
[1.7489664881918983E-5]
[6.205315404676972E-6]
[8.41530873105465E-6]
[1.1853929628341918E-5]
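On the sign question, a short derivation for a single output weight may help. It assumes a squared-error loss and the error convention used in train (targets minus output); the symbols t (target), o (output), h_j (hidden value), w_j (hidden-to-output weight), b (bias), and \eta (learning rate) are introduced here for illustration only.

```latex
\begin{align*}
E &= \tfrac{1}{2}(t - o)^2, \qquad o = \sigma(z), \qquad z = \sum_j w_j h_j + b \\
\frac{\partial E}{\partial w_j} &= -(t - o)\, o(1 - o)\, h_j \\
w_j &\leftarrow w_j - \eta\, \frac{\partial E}{\partial w_j}
     = w_j + \eta\, (t - o)\, o(1 - o)\, h_j
\end{align*}
```

Under these assumptions the minus sign of gradient descent cancels against the minus sign that comes from differentiating (t - o), so adding the update, as the book and the video do, already moves against the gradient. Subtracting would then move with the gradient, which would be consistent with the saturated outputs near 0 or 1 shown above.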
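I am fairly sure the problem is not in the Matrix class I wrote, because I checked it beforehand and the addition, subtraction, multiplication, and transpose operations all worked correctly.
I would really appreciate it if someone could look at this code and help me figure out the problem.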