I am training my network, but the error never drops below 0.9212499999999999, and the network gives the same output for every input.
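import org.encog.Encog;
import org.encog.engine.network.activation.ActivationReLU;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

// loads the training data: getX() = inputs, getY() = ideal outputs
JAVAML jml = new JAVAML();
// create a neural network, without using a factory
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(null, true, 3));                      // 3 input neurons
network.addLayer(new BasicLayer(new ActivationReLU(), true, 50));     // 50 ReLU hidden neurons
network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 2));  // 2 sigmoid outputs, range (0, 1)
network.getStructure().finalizeStructure();
network.reset();
// create training data
MLDataSet dataSet = new BasicMLDataSet(jml.getX(), jml.getY());
// train the neural network
final ResilientPropagation train = new ResilientPropagation(network, dataSet);
int epoch = 1;
do {
    train.iteration();
    System.out.println("Epoch #" + epoch + " Error:" + train.getError());
    epoch++;
} while (train.getError() > 0.01 && epoch < 3500);
train.finishTraining();
// test the neural network
System.out.println("Neural Network Results:");
for (MLDataPair pair : dataSet) {
    final MLData output = network.compute(pair.getInput());
    System.out.println(pair.getInput().getData(0) + ", " + pair.getInput().getData(1) + ", " + pair.getInput().getData(2)
            + ", actual=" + output.getData(0) + ", " + output.getData(1)
            + ", ideal=" + pair.getIdeal().getData(0) + ", " + pair.getIdeal().getData(1));
}
Encog.getInstance().shutdown();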
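A note on the output layer: a sigmoid activation maps any net input into the open interval (0, 1), so ideal values such as 2.5 can never be produced. The plain-Java sketch below (no Encog dependency; the class name SigmoidSaturation is hypothetical) evaluates the standard logistic function 1 / (1 + e^-x) that sigmoid activations are based on. It also shows that once the net input is large, which is easy to reach when unscaled features like 982.0 pass through the ReLU layer, the result rounds to exactly 1.0 in double precision and the textbook derivative s(1 - s) becomes 0.

```java
// Plain-Java sketch, no Encog dependency: logistic sigmoid and its textbook derivative.
class SigmoidSaturation {
    // Standard logistic function: 1 / (1 + e^-x).
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Derivative written in terms of the output: s * (1 - s).
    static double sigmoidDerivative(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    public static void main(String[] args) {
        // Large net inputs are easy to reach when raw features such as 982.0 are fed in unscaled.
        for (double net : new double[] {0.0, 2.0, 10.0, 40.0, 900.0}) {
            System.out.println("net=" + net
                    + " sigmoid=" + sigmoid(net)
                    + " derivative=" + sigmoidDerivative(net));
        }
    }
}
```

From net = 40 upward, e^-x is smaller than half the spacing between doubles near 1.0, so the sum 1 + e^-x is stored as exactly 1.0: the output is stuck at 1.0 no matter how much larger the net input grows.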
Sample of the dataset (the first 3 columns are inputs, the last 2 are the ideal outputs):
1.0 68.0 256.0 1.0 2.5
2.0 982.0 102.0 2.5 0.5
3.0 821.0 354.0 2.5 1.0
4.0 772.0 204.0 2.5 1.0
5.0 115.0 235.0 1.0 2.5
6.0 824.0 179.0 2.5 0.5
7.0 775.0 258.0 2.5 1.0
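The three input columns differ in scale by about two orders of magnitude (1 to 7, 68 to 982 and 102 to 354 in this sample). A common remedy is to scale each input column into [0, 1] before building the BasicMLDataSet. The sketch below is plain Java with a hypothetical MinMaxScaler class, not Encog's own normalization API; it assumes the per-column minimum and maximum are learned from the training rows and then reused for any new inputs.

```java
import java.util.Arrays;

// Hypothetical helper (not part of Encog): column-wise min-max scaling into [0, 1].
class MinMaxScaler {
    final double[] min;
    final double[] max;

    // Learn the per-column minimum and maximum from the training rows.
    MinMaxScaler(double[][] rows) {
        int cols = rows[0].length;
        min = new double[cols];
        max = new double[cols];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : rows) {
            for (int c = 0; c < cols; c++) {
                min[c] = Math.min(min[c], row[c]);
                max[c] = Math.max(max[c], row[c]);
            }
        }
    }

    // Map one row into [0, 1]; a constant column is mapped to 0.
    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int c = 0; c < row.length; c++) {
            double range = max[c] - min[c];
            out[c] = range == 0.0 ? 0.0 : (row[c] - min[c]) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        // The three input columns of the sample rows above.
        double[][] inputs = {
            {1.0, 68.0, 256.0}, {2.0, 982.0, 102.0}, {3.0, 821.0, 354.0},
            {4.0, 772.0, 204.0}, {5.0, 115.0, 235.0}, {6.0, 824.0, 179.0},
            {7.0, 775.0, 258.0}
        };
        MinMaxScaler scaler = new MinMaxScaler(inputs);
        for (double[] row : inputs) {
            System.out.println(Arrays.toString(scaler.transform(row)));
        }
    }
}
```

The scaled rows would then replace jml.getX() when the BasicMLDataSet is created.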
Training error during propagation:
Epoch #14 Error:0.9212500000000001
Epoch #15 Error:0.9212499999999999
Epoch #16 Error:0.9212499999999999
Epoch #17 Error:0.9212499999999999
Epoch #18 Error:0.9212499999999999
Epoch #19 Error:0.9212500000000001
Epoch #20 Error:0.92125
Epoch #21 Error:0.9212499999999999
Epoch #22 Error:0.9212499999999999
Epoch #23 Error:0.9212499999999999
Epoch #24 Error:0.9212499999999999
Epoch #25 Error:0.9212499999999999
Epoch #26 Error:0.9212499999999999
Epoch #27 Error:0.9212499999999999
Epoch #28 Error:0.92125
Epoch #29 Error:0.9212499999999999
Epoch #30 Error:0.92125
Epoch #31 Error:0.9212499999999999
Epoch #32 Error:0.9212499999999999
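A plateau like this is what you would expect if the error has a floor that a (0, 1) output cannot get under. The sketch below computes that floor for the seven sample rows above, assuming the reported error is a mean squared error over every output value: each target is compared with the closest value a sigmoid output could approach, namely the target clamped to [0, 1]. SigmoidErrorFloor is a hypothetical name.

```java
// Sketch: lower bound on mean squared error when every output is confined to [0, 1].
class SigmoidErrorFloor {
    // The closest value a sigmoid output can approach for a given target.
    static double bestReachable(double target) {
        return Math.max(0.0, Math.min(1.0, target));
    }

    // MSE over all output values if each output got as close to its target as the range allows.
    static double mseFloor(double[][] ideals) {
        double sum = 0.0;
        int count = 0;
        for (double[] row : ideals) {
            for (double t : row) {
                double diff = t - bestReachable(t);
                sum += diff * diff;
                count++;
            }
        }
        return sum / count;
    }

    public static void main(String[] args) {
        // The two ideal columns of the seven sample rows in the question.
        double[][] ideals = {
            {1.0, 2.5}, {2.5, 0.5}, {2.5, 1.0}, {2.5, 1.0},
            {1.0, 2.5}, {2.5, 0.5}, {2.5, 1.0}
        };
        System.out.println("MSE floor for a (0, 1) output: " + mseFloor(ideals));
    }
}
```

For these seven rows it prints 1.125: seven of the fourteen targets are 2.5, each contributing (2.5 - 1)^2 = 2.25. That is not the 0.92125 in the log, which presumably comes from the full dataset, but it shows the floor is positive and fixed, so more epochs cannot remove it.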
Output:
1.0, 165.0, 73.0, actual=1.0, 1.0, ideal=2.5, 1.0
2.0, 385.0, 191.0, actual=1.0, 1.0, ideal=2.5, 1.0
3.0, 418.0, 405.0, actual=1.0, 1.0, ideal=1.5, 1.5
4.0, 63.0, 257.0, actual=1.0, 1.0, ideal=0.5, 2.5
5.0, 586.0, 5.0, actual=1.0, 1.0, ideal=2.5, 0.5
6.0, 159.0, 384.0, actual=1.0, 1.0, ideal=1.0, 2.5
7.0, 953.0, 153.0, actual=1.0, 1.0, ideal=2.5, 0.5
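Every row prints actual=1.0, 1.0 while the ideal values range from 0.5 to 2.5, much of which lies outside what a sigmoid output can produce. One option is to rescale the ideals into [0, 1] for training and map the outputs back afterwards; another is to keep the raw ideals and use a linear activation (Encog's ActivationLinear) on the last layer. The sketch below shows only the first option in plain Java. TargetScaler is a hypothetical helper, and the bounds 0.5 and 2.5 are taken from the rows shown here and assumed to hold for the whole dataset.

```java
// Hypothetical helper: scale targets from [lo, hi] into [0, 1] for a sigmoid output, and back.
class TargetScaler {
    final double lo;
    final double hi;

    TargetScaler(double lo, double hi) {
        this.lo = lo;
        this.hi = hi;
    }

    // Value to put in the ideal array during training.
    double toUnit(double target) {
        return (target - lo) / (hi - lo);
    }

    // Convert a network output back into the original target units.
    double fromUnit(double output) {
        return lo + output * (hi - lo);
    }

    public static void main(String[] args) {
        // Assumption: 0.5 and 2.5 are the smallest and largest ideal values in the dataset.
        TargetScaler scaler = new TargetScaler(0.5, 2.5);
        for (double t : new double[] {0.5, 1.0, 1.5, 2.5}) {
            double unit = scaler.toUnit(t);
            System.out.println(t + " -> " + unit + " -> " + scaler.fromUnit(unit));
        }
    }
}
```

Because a sigmoid only approaches 0 and 1 asymptotically, some setups scale into a slightly narrower band such as [0.1, 0.9]; that is a tuning choice rather than a requirement.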
My question is: how can I fix this so that the training error goes down and the network produces outputs other than 1?