我最近一直试图从头开始制作前馈神经网络。它已经能够学习XOR功能,虽然它比其他人描述的时间更多。
我的网络使用0.25的学习率,没有动量,激活函数是sigmoid函数。 (1 /(1 + E ^ -x))
我之前已经阅读过使用MNIST数据库的培训只需要1k个纪元左右。然而,对于像XOR这样简单的事情,我的网络需要10万个时期(几秒钟)才能获得正确的XOR答案的0.05。
例如,对于1,0,NN将输出0.96。
它最终会进行训练,但出于某种原因,它所采用的时代比描述的多了许多倍。
所以我的问题是:为什么我的网络需要这么长时间?我怀疑它与反向传播有关,虽然我不知道如何验证。
可以看到整个项目here on github.
前传:
private double evaluate(double[] input){
//reset the neuron values
for(NeuronLayer layer:layers){
for(Neuron neuron:layer.getNeurons()){
neuron.weightedSum = 0;
}
}
//set the output of the input neurons to the input
for(int i = 0; i < layers[0].getNeurons().length; i++){
layers[0].getNeurons()[i].output = input[i];
}
//cycle through all the neurons
for(int i = 0; i < layers.length; i++){
for(Neuron neuron:layers[i].getNeurons()){
if(i != 0) neuron.activationFunction(); //apply the activation function if not an input neuron
if(i != layers.length - 1){
for(Dendrite dendrite:neuron.getDendrites()){
//Increment the weightedSum of the destination neuron by the source neuron output scaled by the weight
dendrite.getEnd().weightedSum += neuron.output * dendrite.weight;
}
}
}
}
double result = layers[layers.length-1].getNeurons()[0].weightedSum; //return the output of the first output neuron.
return result;
}
获取神经元错误:
void getErrors(double result, double expectedResult){
for(int i = layers.length - 1; i > 0; i--){
NeuronLayer layer = layers[i];
for(int j = 0; j < layer.getNeurons().length; j++){
Neuron neuron = layer.getNeurons()[j];
double neuronError = 0;
if(i == layers.length - 1){
neuronError = neuron.getDerivative() * (result - expectedResult);
}
else{
neuronError = neuron.getDerivative();
double sum = 0;
for(Dendrite dendrite:neuron.getDendrites()){
sum += dendrite.weight * dendrite.getEnd().getError();
}
neuronError *= sum;
}
neuron.setError(neuronError);
}
}
}
根据错误更新权重:
void updateWeights(HashMap<Dendrite,Double> dendriteDeltaMap, double learningRate, double momentum){
for(int i = layers.length - 1; i > 0; i--){
NeuronLayer layer = layers[i];
for(Neuron neuron:layer.getNeurons()){
for(Dendrite dendrite:neuron.getInputs()){
double delta = learningRate * neuron.getError() * dendrite.getStart().getOutput();
if(dendriteDeltaMap.get(dendrite) != null){
delta += momentum * dendriteDeltaMap.get(dendrite);
}
dendriteDeltaMap.put(dendrite, delta);
dendrite.adjustWeight(-delta);
}
}
}
}
训练功能:
public void train(double[][] inputs, double[] outputs, double learningRate, double momentum, int maxIterations){
int runs = 0;
double startError = 0;
while(true){
HashMap<Dendrite,Double> dendriteDeltaMap = new HashMap<>();
double errorSum = 0;
for(int i = 0; i < inputs.length; i++){
double sum = evaluate(inputs[i]);//get sum
double result = sigmoid(sum); //calculate final result
double error = Math.pow(outputs[i]-result,2)/2; //calculate mean squared error
errorSum += error;
//System.out.println("Error: " + error);
getErrors(result, outputs[i]);
updateWeights(dendriteDeltaMap, learningRate, momentum);
}
double avgError = errorSum/inputs.length;
if(runs == 0) startError = avgError;
System.out.println("Epoch: " + runs + ", error: " + avgError);
runs++;
if(runs>=maxIterations || avgError <= Math.pow(0.03, 2)/2) break;
}
System.out.println("\nFinished!");
System.out.println("Start error: " + startError);
printWeights();
}