Encog neural network not training correctly

Date: 2019-08-12 16:45:08

Tags: java machine-learning deep-learning encog

I am training my network, but the error never drops below 0.9212499999999999, and the network produces the same output for every input.

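    import org.encog.Encog;
    import org.encog.engine.network.activation.ActivationReLU;
    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.ml.data.MLData;
    import org.encog.ml.data.MLDataPair;
    import org.encog.ml.data.MLDataSet;
    import org.encog.ml.data.basic.BasicMLDataSet;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

    // JAVAML is my own data-loading class (not shown):
    // getX() returns the inputs, getY() the ideal outputs
    JAVAML jml = new JAVAML();

    // create a neural network, without using a factory
    BasicNetwork network = new BasicNetwork();
    network.addLayer(new BasicLayer(null, true, 3));                      // input layer: 3 neurons + bias
    network.addLayer(new BasicLayer(new ActivationReLU(), true, 50));     // hidden layer: 50 ReLU neurons + bias
    network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 2));  // output layer: 2 sigmoid neurons
    network.getStructure().finalizeStructure();
    network.reset();                                                      // randomize the weights

    // create training data
    MLDataSet dataSet = new BasicMLDataSet(jml.getX(), jml.getY());

    // train the neural network
    final ResilientPropagation train = new ResilientPropagation(network, dataSet);

    int epoch = 1;

    do {
        train.iteration();
        System.out.println("Epoch #" + epoch + " Error:" + train.getError());
        epoch++;
    } while (train.getError() > 0.01 && epoch < 3500);  // stop at error <= 0.01 or after 3500 epochs
    train.finishTraining();

    // test the neural network
    System.out.println("Neural Network Results:");
    for (MLDataPair pair : dataSet) {
        final MLData input = pair.getInput();
        final MLData ideal = pair.getIdeal();
        final MLData output = network.compute(input);
        System.out.println(input.getData(0) + ", " + input.getData(1) + ", " + input.getData(2)
                + ", actual=" + output.getData(0) + ", " + output.getData(1)
                + ", ideal=" + ideal.getData(0) + ", " + ideal.getData(1));
    }

    Encog.getInstance().shutdown();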
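
One possible reason the outputs sit at exactly 1.0 is sigmoid saturation. The sketch below is plain Java (no Encog) and uses a deliberately simplified, hypothetical network: a single input value feeds 50 identical ReLU hidden units through a weight of 0.1, the output unit sums them with the same weight, and biases are ignored. The class, method names, and weight value are assumptions for illustration only. With the raw input 982.0 the output pre-activation is about 491, the sigmoid evaluates to exactly 1.0 in double precision and its derivative to 0.0; with the same value scaled to 0.98 the unit stays in its responsive range.

```java
// Simplified, hypothetical network: one input value feeds 50 identical ReLU hidden
// units, which feed one sigmoid output unit; every weight is 0.1 and biases are ignored.
class SaturationDemo {

    static final double WEIGHT = 0.1;    // hypothetical weight, for illustration only
    static final int HIDDEN_UNITS = 50;  // same width as the hidden layer in the question

    static double relu(double z) {
        return Math.max(0.0, z);
    }

    // Logistic sigmoid 1 / (1 + e^-z)
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Sigmoid derivative written in terms of its output: s * (1 - s)
    static double sigmoidDerivative(double z) {
        double s = sigmoid(z);
        return s * (1.0 - s);
    }

    // Pre-activation of the output unit: 50 identical hidden activations, each times WEIGHT
    static double outputPreActivation(double x) {
        double hidden = relu(WEIGHT * x);
        return HIDDEN_UNITS * WEIGHT * hidden;
    }

    public static void main(String[] args) {
        for (double x : new double[] {982.0, 0.98}) {  // raw vs. scaled value of one input
            double z = outputPreActivation(x);
            System.out.println("x=" + x + "  z=" + z
                    + "  sigmoid=" + sigmoid(z)
                    + "  derivative=" + sigmoidDerivative(z));
        }
    }
}
```

A derivative of 0.0 means gradients backpropagated through that unit vanish, which would be consistent with an error that stops moving; whether this is what happens in the real network depends on its actual weights.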

Sample of the dataset (the first 3 columns are the inputs, the last 2 are the ideal outputs):

    1.0    68.0   256.0   1.0   2.5
    2.0   982.0   102.0   2.5   0.5
    3.0   821.0   354.0   2.5   1.0
    4.0   772.0   204.0   2.5   1.0
    5.0   115.0   235.0   1.0   2.5
    6.0   824.0   179.0   2.5   0.5
    7.0   775.0   258.0   2.5   1.0
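
The input columns span very different ranges (about 1 to 7, 68 to 982, and 102 to 354 in this sample). A common first step is to rescale each input column to [0, 1] before training. A minimal min-max sketch in plain Java, assuming the data are available as `double[][]` rows like the ones `JAVAML` presumably returns; `MinMaxScaling` and `scaleColumns` are hypothetical names, not Encog API:

```java
import java.util.Arrays;

// Hypothetical helper: per-column min-max scaling to [0, 1].
class MinMaxScaling {

    // Scales every column independently: (v - min) / (max - min).
    // Returns a new array; the input is not modified.
    static double[][] scaleColumns(double[][] data) {
        int rows = data.length;
        int cols = data[0].length;
        double[][] scaled = new double[rows][cols];
        for (int c = 0; c < cols; c++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                min = Math.min(min, row[c]);
                max = Math.max(max, row[c]);
            }
            double range = max - min;
            for (int r = 0; r < rows; r++) {
                // a constant column is mapped to 0 to avoid division by zero
                scaled[r][c] = range == 0.0 ? 0.0 : (data[r][c] - min) / range;
            }
        }
        return scaled;
    }

    public static void main(String[] args) {
        // input columns of the sample rows shown above
        double[][] inputs = {
            {1.0, 68.0, 256.0}, {2.0, 982.0, 102.0}, {3.0, 821.0, 354.0},
            {4.0, 772.0, 204.0}, {5.0, 115.0, 235.0}, {6.0, 824.0, 179.0},
            {7.0, 775.0, 258.0}
        };
        for (double[] row : scaleColumns(inputs)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

In practice the minimum and maximum would be computed on the training data and reused for any new inputs. The ideal values (0.5 to 2.5 here) also lie outside the (0, 1) range of a sigmoid output, so they would need the same kind of scaling (and inverse scaling after `compute`), or the output layer would need an unbounded activation such as Encog's `ActivationLinear`.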

Propagation error:

    Epoch #14 Error:0.9212500000000001
    Epoch #15 Error:0.9212499999999999
    Epoch #16 Error:0.9212499999999999
    Epoch #17 Error:0.9212499999999999
    Epoch #18 Error:0.9212499999999999
    Epoch #19 Error:0.9212500000000001
    Epoch #20 Error:0.92125
    Epoch #21 Error:0.9212499999999999
    Epoch #22 Error:0.9212499999999999
    Epoch #23 Error:0.9212499999999999
    Epoch #24 Error:0.9212499999999999
    Epoch #25 Error:0.9212499999999999
    Epoch #26 Error:0.9212499999999999
    Epoch #27 Error:0.9212499999999999
    Epoch #28 Error:0.92125
    Epoch #29 Error:0.9212499999999999
    Epoch #30 Error:0.92125
    Epoch #31 Error:0.9212499999999999
    Epoch #32 Error:0.9212499999999999
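
From epoch 14 on, the error in the log changes only in the last floating-point digits, so the training loop would keep running until epoch 3500. A hypothetical plateau check could stop it earlier: the sketch reports a plateau once the best error has not improved by more than `minDelta` for `patience` consecutive epochs. `PlateauDetector`, `minDelta`, and `patience` are illustrative names and values, not part of Encog.

```java
// Hypothetical early-stopping helper (not an Encog class).
class PlateauDetector {
    private final double minDelta;  // smallest decrease that counts as progress
    private final int patience;     // epochs without progress before reporting a plateau
    private double bestError = Double.POSITIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    PlateauDetector(double minDelta, int patience) {
        this.minDelta = minDelta;
        this.patience = patience;
    }

    // Call once per epoch with the current training error; returns true once the
    // error has failed to improve by more than minDelta for `patience` epochs in a row.
    boolean update(double error) {
        if (bestError - error > minDelta) {
            bestError = error;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement >= patience;
    }

    public static void main(String[] args) {
        // errors for epochs 14 to 20 from the log in the question
        double[] errors = {0.9212500000000001, 0.9212499999999999, 0.9212499999999999,
                           0.9212499999999999, 0.9212499999999999, 0.9212500000000001, 0.92125};
        PlateauDetector detector = new PlateauDetector(1e-6, 5);
        for (int i = 0; i < errors.length; i++) {
            if (detector.update(errors[i])) {
                System.out.println("Plateau detected at epoch #" + (14 + i));
                break;
            }
        }
    }
}
```

With these values it reports a plateau at epoch #19. Inside the question's `do`/`while` loop it could be called right after `train.iteration()`, for example `if (detector.update(train.getError())) break;`. Stopping early only saves time; it does not address why the error is stuck.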

Output:

    1.0, 165.0, 73.0, actual=1.0, 1.0, ideal=2.5, 1.0
    2.0, 385.0, 191.0, actual=1.0, 1.0, ideal=2.5, 1.0
    3.0, 418.0, 405.0, actual=1.0, 1.0, ideal=1.5, 1.5
    4.0, 63.0, 257.0, actual=1.0, 1.0, ideal=0.5, 2.5
    5.0, 586.0, 5.0, actual=1.0, 1.0, ideal=2.5, 0.5
    6.0, 159.0, 384.0, actual=1.0, 1.0, ideal=1.0, 2.5
    7.0, 953.0, 153.0, actual=1.0, 1.0, ideal=2.5, 0.5
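
Every ideal value above 1.0 in this output is unreachable for a sigmoid output unit, whose range is (0, 1). A worked check in plain Java, assuming mean squared error over all output values (an assumption for illustration, not necessarily how Encog computes its reported error): for the seven ideal pairs printed above, a constant prediction of 1.0 gives an MSE of 14.75 / 14 ≈ 1.054, and the lowest MSE any output confined to [0, 1] could reach is 14.0 / 14 = 1.0. These rows differ from the dataset sample shown earlier, so the numbers are not expected to match the 0.92125 in the log. `BoundedOutputFloor` and its methods are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical illustration: error of outputs confined to [0, 1] against ideals up to 2.5.
class BoundedOutputFloor {

    // Mean squared error over every output value of every row.
    static double mse(double[][] predicted, double[][] ideal) {
        double sum = 0.0;
        int count = 0;
        for (int r = 0; r < ideal.length; r++) {
            for (int c = 0; c < ideal[r].length; c++) {
                double d = predicted[r][c] - ideal[r][c];
                sum += d * d;
                count++;
            }
        }
        return sum / count;
    }

    // Lowest MSE reachable when each output must lie in [0, 1]:
    // the best bounded prediction for an ideal y is y clamped into [0, 1].
    static double floorForUnitRange(double[][] ideal) {
        double[][] best = new double[ideal.length][];
        for (int r = 0; r < ideal.length; r++) {
            best[r] = new double[ideal[r].length];
            for (int c = 0; c < ideal[r].length; c++) {
                best[r][c] = Math.min(1.0, Math.max(0.0, ideal[r][c]));
            }
        }
        return mse(best, ideal);
    }

    public static void main(String[] args) {
        // ideal pairs from the printed output above
        double[][] ideal = {
            {2.5, 1.0}, {2.5, 1.0}, {1.5, 1.5}, {0.5, 2.5},
            {2.5, 0.5}, {1.0, 2.5}, {2.5, 0.5}
        };
        double[][] allOnes = new double[ideal.length][2];
        for (double[] row : allOnes) {
            Arrays.fill(row, 1.0);  // the network's actual output for every row
        }
        System.out.println("MSE of constant 1.0:    " + mse(allOnes, ideal));
        System.out.println("Floor for [0, 1] range: " + floorForUnitRange(ideal));
    }
}
```

Because the floor depends only on the ideals, no amount of extra training can push this measure below it while the output layer stays bounded to (0, 1).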

My question is: how can I fix this so that the propagation error goes down and the network produces outputs other than 1?

0 answers