FANN doesn't train

Time: 2013-12-23 01:43:32

Tags: machine-learning neural-network

I'm using FANN for function approximation. My code is here:

/*
 * File:   main.cpp
 * Author: johannsebastian
 *
 * Created on November 26, 2013, 8:50 PM
 */

#include "../FANN-2.2.0-Source/src/include/doublefann.h"
#include "../FANN-2.2.0-Source/src/include/fann_cpp.h"
//#include <doublefann>
//#include <fann/fann_cpp>
#include <cstdlib>
#include <iostream>


using namespace std;
using namespace FANN;

//Remember: fann_type is double!

int main(int argc, char** argv) {
    //create a test network: [1,2,1] MLP
    neural_net * net = new neural_net;
    const unsigned int layers[3] = {1, 2, 1};
    net->create_standard_array(3, layers);

    //net->create_standard(num_layers, num_input, num_hidden, num_output);

    //net->set_learning_rate(0.7f);

    //net->set_activation_steepness_hidden(0.7);
    //net->set_activation_steepness_output(0.7);

    net->set_activation_function_hidden(SIGMOID);
    net->set_activation_function_output(SIGMOID);
    net->set_training_algorithm(TRAIN_RPROP);

    //cout<<net->get_train_error_function()
    //exit(0);
    //test the number 2
    fann_type * testinput = new fann_type;
    *testinput = 2;
    fann_type * testoutput = new fann_type;
    *testoutput = *(net->run(testinput));
    double outputasdouble = (double) *testoutput;
    cout << "Test output: " << outputasdouble << endl;

    //make a training set of x->x^2
    training_data * squaredata = new training_data;
    squaredata->read_train_from_file("trainingdata.txt");
    //cout<<testinput[0]<<endl;
    //cout<<testoutput[0]<<endl;
    cout<<*(squaredata->get_input())[9]<<endl;
    cout<<*(squaredata->get_output())[9]<<endl;
    cout<<squaredata->length_train_data();

    //scale data
    fann_type * scaledinput = new fann_type[squaredata->length_train_data()];
    fann_type * scaledoutput = new fann_type[squaredata->length_train_data()];
    for (unsigned int i = 0; i < squaredata->length_train_data(); i++) {
            scaledinput[i] = *squaredata->get_input()[i]/200;///100;
            scaledoutput[i] = *squaredata->get_output()[i]/200;///100;
            cout<<"In:\t"<<scaledinput[i]<<"\t Out:\t"<<scaledoutput[i]<<endl;
    }

    net->train_on_data(*squaredata, 1000000, 100000, 0.001);

    *testoutput = *(net->run(testinput));
    outputasdouble = (double) *testoutput;
    cout << "Test output: " << outputasdouble << endl;

    cout << endl << "Easy!";
    return 0;
}

Here is trainingdata.txt:

10 1 1
1 1
2 4
3 9
4 16
5 25
6 36
7 49
8 64
9 81
10 100

When I run it, I get this:

Test output: 0.491454
10
100
10In:   0.005    Out:   0.005
In:     0.01     Out:   0.02
In:     0.015    Out:   0.045
In:     0.02     Out:   0.08
In:     0.025    Out:   0.125
In:     0.03     Out:   0.18
In:     0.035    Out:   0.245
In:     0.04     Out:   0.32
In:     0.045    Out:   0.405
In:     0.05     Out:   0.5
Max epochs  1000000. Desired error: 0.0010000000.
Epochs            1. Current error: 2493.7961425781. Bit fail 10.
Epochs       100000. Current error: 2457.3000488281. Bit fail 9.
Epochs       200000. Current error: 2457.3000488281. Bit fail 9.
Epochs       300000. Current error: 2457.3000488281. Bit fail 9.
Epochs       400000. Current error: 2457.3000488281. Bit fail 9.
Epochs       500000. Current error: 2457.3000488281. Bit fail 9.
Epochs       600000. Current error: 2457.3000488281. Bit fail 9.
Epochs       700000. Current error: 2457.3000488281. Bit fail 9.
Epochs       800000. Current error: 2457.3000488281. Bit fail 9.
Epochs       900000. Current error: 2457.3000488281. Bit fail 9.
Epochs      1000000. Current error: 2457.3000488281. Bit fail 9.
Test output: 1

Easy!
RUN FINISHED; exit value 0; real time: 9s; user: 10ms; system: 4s

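Why isn't the training working? After I asked a similar question, I was told to scale the inputs and outputs of the NN. I have done that. Am I getting some parameter wrong, or do I just need to train for longer?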

2 answers:

Answer 0 (score: 1)

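The hidden layer has too few nodes to fit a quadratic function. I would try 10. I'd also like to suggest this fun applet, in which you can simulate the training process with different parameter settings. I tried 10 hidden-layer nodes with a unipolar sigmoid as the activation function for both the hidden and output layers, and the fit was not bad. Note that the randomized initial weights can cause convergence to fail, so more nodes in the hidden layer are strongly recommended. You can play with the applet yourself and observe some interesting behavior.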

Answer 1 (score: 0)

This may be a bit late, but perhaps new FANN beginners will see this answer, and I hope it helps!

I think your problem comes from the data format in trainingdata.txt.

See: FANN data format

There must be a line break after each input and after each output.

In your case, you have 10 examples, each with 1 input and 1 output. Your file then has to be formatted like this:

10 1 1
1 
1
2 
4
3 
9
4 
16
5 
25
6 
36
...

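Note: I have noticed that when the data format is wrong, the error computed by the training method is very (very) high. When you see huge error values, take it as a hint to check your file format.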