Multithreading in Java and neural networks: reading 2 CSV files simultaneously

Posted: 2015-10-08 19:16:26

Tags: java multithreading encog

I am working on a neural network project that requires running two instances of a neural network program at the same time, each with a different training set. To do this I am using multithreading in Java together with the Encog library to implement the ANN. So I created two threads, each containing the ANN implementation but working on a different CSV file. It partially works: it returns the minimum and maximum values of each column in the CSV files, but the problem is that the ANN output is only computed for one of the files. Here is my implementation:

ReadFileMT.java

    public class ReadFileMT implements Runnable {

        public static void dumpFieldInfo(EncogAnalyst analyst) {
            System.out.println("Fields found in file:");
            for (AnalystField field : analyst.getScript().getNormalize()
                    .getNormalizedFields()) {
                StringBuilder line = new StringBuilder();
                line.append(field.getName());
                line.append(",action=");
                line.append(field.getAction());
                line.append(",min=");
                line.append(field.getActualLow());
                line.append(",max=");
                line.append(field.getActualHigh());
                System.out.println(line.toString());
            }
        }

        public void run() {
            // analyze and normalize the first CSV file
            File sourceFile = new File("d:\\data\\F21.csv");
            File targetFile = new File("d:\\data\\F2_norm.csv");
            EncogAnalyst analyst = new EncogAnalyst();
            AnalystWizard wizard = new AnalystWizard(analyst);
            AnalystField targetField = wizard.getTargetField();
            wizard.setTargetField("Old_Resp");
            wizard.wizard(sourceFile, true, AnalystFileFormat.DECPNT_COMMA);
            dumpFieldInfo(analyst);

            final AnalystNormalizeCSV norm = new AnalystNormalizeCSV();
            norm.analyze(sourceFile, true, CSVFormat.ENGLISH, analyst);
            norm.setProduceOutputHeaders(true);
            norm.normalize(targetFile);
            // Encog.getInstance().shutdown();

            // ***************** Read from the csv file *****************

            final BasicNetwork network = EncogUtility.simpleFeedForward(4, 4, 0, 1, false);
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 4));
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 4));
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
            network.getStructure().finalizeStructure();
            network.reset();

            // create training data
            final MLDataSet trainingSet = TrainingSetUtil.loadCSVTOMemory(
                    CSVFormat.ENGLISH, "c:\\temp\\F2_norm.csv", false, 4, 1);

            // train the neural network
            System.out.println();
            System.out.println("Training Network");
            final Backpropagation train = new Backpropagation(network, trainingSet, 0.05, 0.9);
            train.fixFlatSpot(false);
            int epoch = 1;

            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:" + train.getError());
                epoch++;
            } while (train.getError() > 0.01);
            train.finishTraining();

            // final Train train = new ResilientPropagation(network, trainingSet);
            /*int epoch = 1;
            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:"
                        + train.getError() * 100 + "%");
                epoch++;
            } while (train.getError() > 0.015);*/

            // test the neural network
            System.out.println("Neural Network Results:");
            for (MLDataPair pair : trainingSet) {
                final MLData output = network.compute(pair.getInput());
                System.out.println(pair.getInput().getData(0) + ","
                        + pair.getInput().getData(1) + ","
                        + pair.getInput().getData(2) + ","
                        + pair.getInput().getData(3)
                        + ", actual=" + output.getData(0)
                        + ",ideal=" + pair.getIdeal().getData(0));
            }

            Encog.getInstance().shutdown();
        }
    }

ReadFileMT2.java

    public class ReadFileMT2 implements Runnable {

        public static void dumpFieldInfo(EncogAnalyst analyst) {
            System.out.println("Fields found in file:");
            for (AnalystField field : analyst.getScript().getNormalize()
                    .getNormalizedFields()) {
                StringBuilder line = new StringBuilder();
                line.append(field.getName());
                line.append(",action=");
                line.append(field.getAction());
                line.append(",min=");
                line.append(field.getActualLow());
                line.append(",max=");
                line.append(field.getActualHigh());
                System.out.println(line.toString());
            }
        }

        public void run() {
            // analyze and normalize the second CSV file
            File sourceFile = new File("d:\\data\\RespTime.csv");
            File targetFile = new File("d:\\data\\RespTime_norm.csv");
            EncogAnalyst analyst = new EncogAnalyst();
            AnalystWizard wizard = new AnalystWizard(analyst);
            AnalystField targetField = wizard.getTargetField();
            wizard.setTargetField("Old_Resp");
            wizard.wizard(sourceFile, true, AnalystFileFormat.DECPNT_COMMA);
            dumpFieldInfo(analyst);

            final AnalystNormalizeCSV norm = new AnalystNormalizeCSV();
            norm.analyze(sourceFile, true, CSVFormat.ENGLISH, analyst);
            norm.setProduceOutputHeaders(true);
            norm.normalize(targetFile);
            // Encog.getInstance().shutdown();

            // ****** Read from the csv file ******

            final BasicNetwork network = EncogUtility.simpleFeedForward(4, 4, 0, 1, false);
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 4));
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 4));
            network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
            network.getStructure().finalizeStructure();
            network.reset();

            // create training data
            final MLDataSet trainingSet = TrainingSetUtil.loadCSVTOMemory(
                    CSVFormat.ENGLISH, "c:\\temp\\RespTime_norm.csv", false, 4, 1);

            // train the neural network
            System.out.println();
            System.out.println("Training Network");
            final Backpropagation train = new Backpropagation(network, trainingSet, 0.05, 0.9);
            train.fixFlatSpot(false);
            int epoch = 1;

            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:" + train.getError());
                epoch++;
            } while (train.getError() > 0.01);
            train.finishTraining();

            /*int epoch = 1;
            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:"
                        + train.getError() * 100 + "%");
                epoch++;
            } while (train.getError() > 0.015);*/

            // test the neural network
            System.out.println("Neural Network Results:");
            for (MLDataPair pair : trainingSet) {
                final MLData output = network.compute(pair.getInput());
                System.out.println(pair.getInput().getData(0) + ","
                        + pair.getInput().getData(1) + ","
                        + pair.getInput().getData(2) + ","
                        + pair.getInput().getData(3)
                        + ", actual=" + output.getData(0)
                        + ",ideal=" + pair.getIdeal().getData(0));
            }

            Encog.getInstance().shutdown();
        }
    }

main.java

    public static void main(String[] args) {

        ReadFileMT obj1 = new ReadFileMT();
        ReadFileMT2 obj2 = new ReadFileMT2();
        Thread t1 = new Thread(obj1);
        Thread t2 = new Thread(obj2);
        t1.start();
        t2.start();
    }
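
For reference, the sketch below is not part of the original code: it shows one way (using a hypothetical wrapper class `MainWithJoin`) to start the same two threads and then wait for both with `Thread.join()`, so that `main` does not return before both runs have finished. Whether this changes the behaviour described above has not been verified.

    public class MainWithJoin {  // hypothetical class name, not from the original post
        public static void main(String[] args) throws InterruptedException {
            // start both workers, each training on its own CSV file
            Thread t1 = new Thread(new ReadFileMT());
            Thread t2 = new Thread(new ReadFileMT2());
            t1.start();
            t2.start();

            // block until both threads have finished training and printing results
            t1.join();
            t2.join();
        }
    }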

I don't understand what is going wrong. P.S.: I have only just started with parallel programming.

0 Answers:

There are no answers yet.