Encog: non-converging error rate

Time: 2015-05-20 11:57:01

Tags: java neural-network encog

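I am building a neural network for image recognition with the Encog 3.3.0 library. I have converted my images to 50x50 grayscale to avoid confusing the network, because I basically want to do some colour-independent feature extraction from the images. I have two output classes.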
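A minimal sketch of that preprocessing step, assuming the images can be loaded as `java.awt.image.BufferedImage` (the post does not say which tool was used). The class name `GrayscalePreprocessor` is hypothetical, and the row-major pixel order is an assumption about how the 2500 CSV columns were laid out. Java2D performs the colour-to-gray conversion when it draws into a `TYPE_BYTE_GRAY` image.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

// Hypothetical helper: not part of Encog or the original post.
class GrayscalePreprocessor {
    static final int SIZE = 50; // 50x50, as described in the question

    // Scales the source to SIZE x SIZE and draws it into an 8-bit gray image;
    // Java2D converts the colours to gray while drawing.
    static BufferedImage toGray50(BufferedImage src) {
        BufferedImage gray = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(src, 0, 0, SIZE, SIZE, null);
        g.dispose();
        return gray;
    }

    // Flattens the gray image row by row (assumed order) into SIZE*SIZE
    // values in the range 0..255, one value per CSV pixel column.
    static double[] flatten(BufferedImage gray) {
        double[] pixels = new double[SIZE * SIZE];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                pixels[y * SIZE + x] = gray.getRaster().getSample(x, y, 0);
            }
        }
        return pixels;
    }
}
```

Typical use would be `flatten(toGray50(ImageIO.read(file)))`. Because the target is always square, non-square source images get stretched, which is related to the answer's question about the original image size.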

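My input: a CSV file with 318 rows and 2502 columns per row. Each row corresponds to one image: the first 2500 columns are the 50x50 pixels of the image and the last 2 columns are the output classes. 159 rows contain 2500 normal image pixels followed by 1,0 as the output, and the other 159 rows contain 2500 normal image pixels followed by 0,1 as the output. 0 means the image does not belong to the class and 1 means it does.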
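The row layout can be checked with a small plain-Java sketch that does not use Encog. The class name `RowParser` is hypothetical; it splits one line into the 2500 pixel inputs and the two ideal outputs and rejects rows with the wrong number of columns.

```java
// Hypothetical helper, plain Java (no Encog): splits one CSV row of the
// layout described above into pixel inputs and the two ideal outputs.
class RowParser {
    static final int COLUMNS = 2500; // 50x50 pixels
    static final int OUTPUT = 2;     // two class columns

    // Returns {input, ideal}.
    static double[][] parse(String line) {
        String[] cells = line.split(",");
        if (cells.length != COLUMNS + OUTPUT) {
            throw new IllegalArgumentException(
                    "expected " + (COLUMNS + OUTPUT) + " columns, got " + cells.length);
        }
        double[] input = new double[COLUMNS];
        double[] ideal = new double[OUTPUT];
        for (int i = 0; i < COLUMNS; i++) {
            input[i] = Double.parseDouble(cells[i].trim()); // tolerate stray spaces
        }
        for (int i = 0; i < OUTPUT; i++) {
            ideal[i] = Double.parseDouble(cells[COLUMNS + i].trim());
        }
        return new double[][] { input, ideal };
    }
}
```

A row with 2501 or 2503 fields, for example from a trailing comma, fails immediately instead of silently shifting the class columns.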

My input: 318 rows and 2502 columns. Here is one of the rows:

255,243,251,255,244,255,235,67,51,52,53,54,54,54,53,53,53,54,55,55,.......,53,54,54,53,53,52 ,54,54,54,54,54,54,54,54,57,57,57,57,57,57,57,57,57,57,0,1

The final 0,1 is the output class.
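Each class is encoded as a two-value pattern (1,0 or 0,1), so a computed output can be mapped back to a class by picking the larger of the two values. The following is a minimal sketch. `ClassDecoder`, `argmax` and `oneHot` are hypothetical names, and mapping index 0 to the 1,0 pattern is an assumption that simply follows the column order.

```java
// Hypothetical helper: converts between a class index and the
// two-column output pattern used in the CSV.
class ClassDecoder {
    // Index of the largest output; {0,1} -> 1, {1,0} -> 0.
    static int argmax(double[] outputs) {
        int best = 0;
        for (int i = 1; i < outputs.length; i++) {
            if (outputs[i] > outputs[best]) {
                best = i;
            }
        }
        return best;
    }

    // Builds the ideal pattern for a class, e.g. oneHot(1, 2) -> {0.0, 1.0}.
    static double[] oneHot(int classIndex, int classes) {
        double[] v = new double[classes];
        v[classIndex] = 1.0;
        return v;
    }
}
```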

My layers: I have 3 layers. The input layer has 2500 neurons, the hidden layer has 1000 neurons, and the output layer has 2 neurons.
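For scale, the 2500-1000-2 layout has roughly 2.5 million trainable weights. The sketch below counts them under the assumption that one bias neuron feeds each hidden and output layer. `EncogUtility.simpleFeedForward` appears to add biases this way, but treat that as an assumption. `WeightCount` is a hypothetical name.

```java
// Counts trainable weights in a fully connected network, assuming one
// bias neuron feeding each non-input layer.
class WeightCount {
    static long weights(int[] layers) {
        long total = 0;
        for (int i = 0; i + 1 < layers.length; i++) {
            total += (long) (layers[i] + 1) * layers[i + 1]; // +1 for the bias
        }
        return total;
    }

    public static void main(String[] args) {
        long w = weights(new int[] {2500, 1000, 2});
        System.out.println(w); // prints 2503002
    }
}
```

That works out to about 7,870 weights per training row (2,503,002 / 318).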

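Problem: when I train the network with a learning rate of 0.7 and a momentum of 0.8, the error rate does not converge even after 100 iterations. It keeps oscillating around 0.45-0.5.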
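For reference, the classic backpropagation-with-momentum rule updates each weight by Δw(t) = -η·∂E/∂w + α·Δw(t-1), where η is the learning rate (0.7 here) and α is the momentum (0.8 here). The sketch below applies that textbook rule to a single weight. It is an illustration, not Encog's internal implementation, and `MomentumUpdate` is a hypothetical name. With a constant gradient g, the step grows toward -η·g/(1-α), which for these settings is 3.5 times the gradient.

```java
// Hypothetical, textbook backpropagation-with-momentum update for one weight;
// not Encog's internal implementation.
class MomentumUpdate {
    double weight;          // current weight value
    double lastDelta = 0.0; // change applied in the previous step
    final double learningRate;
    final double momentum;

    MomentumUpdate(double weight, double learningRate, double momentum) {
        this.weight = weight;
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    // gradient is dE/dw for this weight: step against it and add a
    // fraction (momentum) of the previous step.
    void step(double gradient) {
        double delta = -learningRate * gradient + momentum * lastDelta;
        weight += delta;
        lastDelta = delta;
    }

    public static void main(String[] args) {
        MomentumUpdate w = new MomentumUpdate(0.0, 0.7, 0.8); // the question's settings
        for (int i = 1; i <= 5; i++) {
            w.step(1.0); // constant gradient of 1
            System.out.println("step " + i + ": delta=" + w.lastDelta);
        }
    }
}
```

The printed deltas keep growing toward -3.5, so with these settings the effective step is several times larger than the nominal learning rate.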

Here is my code:

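import java.io.IOException;

import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;

public class image_recognition {

    static final int COLUMNS = 2500; // 50x50 grayscale pixels per image
    static final int OUTPUT = 2;     // two output classes

    public BasicNetwork network;
    public double[][] input;
    public double[][] ideal;
    public MLDataSet trainingSet;

    public void createNetwork() {
        // simpleFeedForward(int input, int hidden1, int hidden2, int output, boolean tanh)
        // hidden2 = 0 gives a single hidden layer; tanh = false selects the sigmoid activation
        network = EncogUtility.simpleFeedForward(image_recognition.COLUMNS, 1000, 0, image_recognition.OUTPUT, false);
        network.reset(); // randomize the weights
    }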

    public void train() {
        // BasicMLDataSet(double[][] input, double[][] ideal)
        trainingSet = new BasicMLDataSet(input, ideal);
        // Backpropagation(ContainsFlat network, MLDataSet training, double learnRate, double momentum)
        final Backpropagation train = new Backpropagation(network, trainingSet, 0.7, 0.8);

        int epoch = 1;

        // train until epoch reaches 5000 (at most 4999 iterations)
        // or the training error drops to 0.3 or below
        do {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            long time = System.currentTimeMillis();
            System.out.println("after iteration time :: ");
            System.out.println(time);
            epoch++;
        } while ((epoch < 5000) && (train.getError() > 0.3));
        train.finishTraining();
    }

    public double evaluate() {
        System.out.println("Neural Network Results:");
        // compare the network's two outputs with the ideal values for every training row
        for (MLDataPair pair : trainingSet) {
            final MLData output = network.compute(pair.getInput());
            String actualoutput1 = String.format("%.6f", output.getData(0));
            String idealoutput1 = String.format("%.1f", pair.getIdeal().getData(0));
            String actualoutput2 = String.format("%.6f", output.getData(1));
            String idealoutput2 = String.format("%.1f", pair.getIdeal().getData(1));
            System.out.println("actual1 = " + actualoutput1 + ", actual2 = " + actualoutput2
                    + " ,ideal1 = " + idealoutput1 + "  ,ideal2 = " + idealoutput2);
        }
        return 0;
    }

    public void load(String filename) throws IOException {
        int size = 0;

        // first pass: count the rows (the file has no header row)
        ReadCSV csv = new ReadCSV(filename, false, CSVFormat.DECIMAL_POINT);
        while (csv.next()) {
            size++;
        }
        csv.close();

        // allocate enough space
        input = new double[size][image_recognition.COLUMNS];
        ideal = new double[size][image_recognition.OUTPUT];

        // second pass: the first COLUMNS fields are pixels,
        // the last OUTPUT fields are the class columns
        int index = 0;
        csv = new ReadCSV(filename, false, CSVFormat.DECIMAL_POINT);
        while (csv.next()) {
            for (int i = 0; i < image_recognition.COLUMNS; i++) {
                input[index][i] = Double.parseDouble(csv.get(i));
            }
            for (int i = 0; i < image_recognition.OUTPUT; i++) {
                ideal[index][i] = Double.parseDouble(csv.get(image_recognition.COLUMNS + i));
            }
            index++;
        }
        csv.close();
    }

    public static void main(final String[] args) {
        try {
            image_recognition prg = new image_recognition();
            long b1 = System.currentTimeMillis();
            System.out.println("before loading time :: ");
            System.out.println(b1);
            prg.load("mycsv.csv");
            long a1 = System.currentTimeMillis();
            System.out.println("after loading, before creating network time :: ");
            System.out.println(a1);
            prg.createNetwork();
            long a2 = System.currentTimeMillis();
            System.out.println("after creating network, before training time :: ");
            System.out.println(a2);
            prg.train();
            long a3 = System.currentTimeMillis();
            System.out.println("after training, before testing time :: ");
            System.out.println(a3);
            prg.evaluate();
        } catch (Throwable t) {
            t.printStackTrace();
        }
    }
}

My output:

Epoch #1 Error:0.48833917036172103
Epoch #2 Error:0.5
Epoch #3 Error:0.5
Epoch #4 Error:0.5
Epoch #5 Error:0.45956570930539425
...
Epoch #23 Error:0.4744859426599884
Epoch #24 Error:0.5
Epoch #25 Error:0.5
...
Epoch #49 Error:0.5912731593753425
Epoch #50 Error:0.5
Epoch #51 Error:0.5031968130459842
...
Epoch #71 Error:0.5046318360708989
Epoch #72 Error:0.49357338328109024
Epoch #73 Error:0.486820369587797
...
Epoch #103 Error:0.5155249407683976
Epoch #104 Error:0.4835673679113441
Epoch #105 Error:0.49407335871268354
...
Epoch #142 Error:0.49038913805594664
Epoch #143 Error:0.4660191340060382

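Please help me understand why the error rate does not converge. I have already tried running it for more iterations, but it still does not converge. I need the error to get down to 0.1 or lower.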
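The value 0.5 that keeps reappearing in the log has a simple interpretation if `train.getError()` reports the mean squared error over every output value. That is Encog's default error calculation as far as I know, but treat it as an assumption. On this balanced set of 159 rows of 1,0 and 159 rows of 0,1, any constant output made of 0s and 1s scores exactly 0.5. This includes 1,1 for every image and 1,0 for every image. A constant 0.5,0.5 scores 0.25. The sketch below checks that arithmetic. `ConstantOutputError` and its methods are hypothetical names.

```java
// Hypothetical check of the mean-squared-error arithmetic; not Encog code.
class ConstantOutputError {
    // Mean of (ideal - actual)^2 over every output value of every row.
    static double mse(double[][] ideal, double[][] actual) {
        double sum = 0.0;
        int count = 0;
        for (int r = 0; r < ideal.length; r++) {
            for (int c = 0; c < ideal[r].length; c++) {
                double d = ideal[r][c] - actual[r][c];
                sum += d * d;
                count++;
            }
        }
        return sum / count;
    }

    // The question's targets: 159 rows of 1,0 followed by 159 rows of 0,1.
    static double[][] balancedTargets() {
        double[][] ideal = new double[318][];
        for (int r = 0; r < 318; r++) {
            ideal[r] = r < 159 ? new double[] {1.0, 0.0} : new double[] {0.0, 1.0};
        }
        return ideal;
    }

    // The same two-value output repeated for every row.
    static double[][] constant(double a, double b, int rows) {
        double[][] out = new double[rows][];
        for (int r = 0; r < rows; r++) {
            out[r] = new double[] {a, b};
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] ideal = balancedTargets();
        System.out.println(mse(ideal, constant(1.0, 1.0, 318))); // prints 0.5
        System.out.println(mse(ideal, constant(1.0, 0.0, 318))); // prints 0.5
        System.out.println(mse(ideal, constant(0.5, 0.5, 318))); // prints 0.25
    }
}
```

An error of exactly 0.5 is therefore consistent with outputs pinned at 0 or 1 regardless of the input image, although the log alone does not prove it.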

1 answer:

Answer 0 (score: 0)

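I would like to clarify a few things. First, which activation function are you using here? Second, what are the parameters of the activation function? Third, what is the original size of the source images? Fourth, if the original images are not square, it may be a good idea to switch from 50x50 to 70x30.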
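The first two questions matter because of how the activation function behaves. The posted code passes `false` for the `tanh` parameter of `simpleFeedForward`, so the network uses a sigmoid activation, which is assumed here to be the logistic sigmoid. The logistic sigmoid's derivative is at most 0.25 and falls toward zero for large positive or negative net inputs. Weights feeding a saturated neuron therefore receive almost no gradient. With 2500 raw pixel inputs in the range 0-255, large net inputs are easy to reach. The sketch below uses hypothetical names, and the sample net inputs are chosen only for illustration.

```java
// Hypothetical illustration; not Encog code. Logistic sigmoid and its
// derivative s(x) * (1 - s(x)), which scales every backpropagated gradient.
class SigmoidSaturation {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    static double derivative(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    public static void main(String[] args) {
        // Example net inputs (weighted sums), chosen for illustration only.
        double[] netInputs = {0.0, 2.0, 10.0, 50.0, 500.0};
        for (double x : netInputs) {
            System.out.printf("net=%6.1f  sigmoid=%.6f  derivative=%.3e%n",
                    x, sigmoid(x), derivative(x));
        }
    }
}
```

At a net input of 50 or more, the sigmoid evaluates to 1.0 in double precision and the derivative is exactly zero, so the corresponding weights stop changing.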