我正在做一个神经网络项目,该项目需要同时运行两个神经网络程序实例,每个实例使用不同的训练集。为此,我在 Java 中结合 Encog 库,使用多线程(multithreading)来实现 ANN。我创建了两个线程,每个线程都包含一份 ANN 实现,但分别读取不同的 CSV 文件。目前只得到了部分结果:程序能够返回 CSV 文件中每一列(column)的最小值和最大值,但问题是 ANN 的输出只针对其中一个文件进行了计算。下面是我的实现:
ReadFileMT.java
public class ReadFileMT implements Runnable {
public static void dumpFieldInfo(EncogAnalyst analyst) {
System.out.println("Fields found in file:");
for (AnalystField field : analyst.getScript().getNormalize()
.getNormalizedFields())
{
StringBuilder line = new StringBuilder();
line.append(field.getName());
line.append(",action=");
line.append(field.getAction());
line.append(",min=");
line.append(field.getActualLow());
line.append(",max=");
line.append(field.getActualHigh());
System.out.println(line.toString());
}
}
public void run() {
File sourceFile = new File("d:\\data\\F21.csv");
File targetFile = new File("d:\\data\\F2_norm.csv");
EncogAnalyst analyst = new EncogAnalyst();
AnalystWizard wizard = new AnalystWizard(analyst);
AnalystField targetField = wizard.getTargetField();
wizard.setTargetField("Old_Resp");
wizard.wizard(sourceFile, true, AnalystFileFormat.DECPNT_COMMA);
dumpFieldInfo(analyst);
final AnalystNormalizeCSV norm = new AnalystNormalizeCSV();
norm.analyze(sourceFile, true, CSVFormat.ENGLISH, analyst);
norm.setProduceOutputHeaders(true);
norm.normalize(targetFile);
// Encog.getInstance().shutdown();
//*****************************Read from the csv file**************************************************
final BasicNetwork network = EncogUtility.simpleFeedForward(4, 4, 0, 1,
false);
network.addLayer(new BasicLayer(new ActivationSigmoid(),false,4));
network.addLayer(newBasicLayer(newActivationSigmoid(),false,4));
network.addLayer(newBasicLayer(newActivationSigmoid(),false,1));
network.getStructure().finalizeStructure();
network.reset();
//create training data
final MLDataSet trainingSet = TrainingSetUtil.loadCSVTOMemory(
CSVFormat.ENGLISH, "c:\\temp\\F2_norm.csv",false, 4, 1);
// train the neural network
System.out.println();
System.out.println("Training Network");
final Backpropagationtrain=newBackpropagation
(network,trainingSet,0.05,0.9);
train.fixFlatSpot(false);
int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" +train.getError());
epoch++;
} while(train.getError() > 0.01);
train.finishTraining();
//final Train train=newResilientPropagation(network,trainingSet);
/*int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:"
+ train.getError() * 100 + "%");
epoch++;
} while (train.getError() > 0.015);*/
/*int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" +train.getError());
epoch++;
} while(train.getError() > 0.01);
train.finishTraining();*/
// test the neural network
System.out.println("Neural Network Results:");
for(MLDataPair pair: trainingSet ) {
final MLData output = network.compute(pair.getInput());
System.out.println(pair.getInput().getData(0)+",
"+pair.getInput().getData(1)+","+ pair.getInput().getData(2)
+","+ pair.getInput().getData(3)
+ ", actual=" + output.getData(0) + ",ideal="
+pair.getIdeal().getData(0));
}
Encog.getInstance().shutdown();
}
}
ReadFileMT2.java
public class ReadFileMT2 implements Runnable {
public static void dumpFieldInfo(EncogAnalyst analyst) {
System.out.println("Fields found in file:");
for (AnalystField field : analyst.getScript().getNormalize()
.getNormalizedFields())
{
StringBuilder line = new StringBuilder();
line.append(field.getName());
line.append(",action=");
line.append(field.getAction());
line.append(",min=");
line.append(field.getActualLow());
line.append(",max=");
line.append(field.getActualHigh());
System.out.println(line.toString());
}
}
public void run() {
File sourceFile = new File("d:\\data\\RespTime.csv");
File targetFile =newFile("d:\\data\\RespTime_norm.csv");
EncogAnalyst analyst = new EncogAnalyst();
AnalystWizard wizard = new AnalystWizard(analyst);
AnalystField targetField = wizard.getTargetField();
wizard.setTargetField("Old_Resp");
wizard.wizard(sourceFile, true, AnalystFileFormat.DECPNT_COMMA);
dumpFieldInfo(analyst);
final AnalystNormalizeCSV norm = new AnalystNormalizeCSV();
norm.analyze(sourceFile, true, CSVFormat.ENGLISH, analyst);
norm.setProduceOutputHeaders(true);
norm.normalize(targetFile);
// Encog.getInstance().shutdown();
//******Read from the csv file*************************
final BasicNetwork network = EncogUtility.simpleFeedForward(4, 4, 0, 1,
false);
network.addLayer(newBasicLayer(newActivationSigmoid(),false,4));
network.addLayer(newBasicLayer(newActivationSigmoid(),false,4));
network.addLayer(newBasicLayer(newActivationSigmoid(),false,1));
network.getStructure().finalizeStructure();
network.reset();
//create training data
final MLDataSet trainingSet = TrainingSetUtil.loadCSVTOMemory(
CSVFormat.ENGLISH, "c:\\temp\\RespTime_norm.csv",false, 4, 1);
// train the neural network
System.out.println();
System.out.println("Training Network");
final Backpropagation train = new Backpropagation
(network,trainingSet,0.05, 0.9);
train.fixFlatSpot(false);
int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" +train.getError());
epoch++;
} while(train.getError() > 0.01);
train.finishTraining();
/*int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:"
+ train.getError() * 100 + "%");
epoch++;
} while (train.getError() > 0.015);*/
/*int epoch = 1;
do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" +train.getError());
epoch++;
} while(train.getError() > 0.01);
train.finishTraining();*/
// test the neural network
System.out.println("Neural Network Results:");
for(MLDataPair pair: trainingSet ) {
final MLData output = network.compute(pair.getInput());
System.out.println(pair.getInput().getData(0) + ","
+pair.getInput().getData(1)+",
"+ pair.getInput().getData(2)+
","+ pair.getInput().getData(3)
+ ", actual=" +output.getData(0)+",ideal="+
pair.getIdeal().getData(0));
}
Encog.getInstance().shutdown();
}
}
main.java
/**
 * Entry point: creates one task of each kind, wraps each in its own thread,
 * and starts the two threads one after the other. It does not wait for them.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    final ReadFileMT firstTask = new ReadFileMT();
    final ReadFileMT2 secondTask = new ReadFileMT2();

    // Both threads are created before either one is started.
    final Thread[] workers = {
        new Thread(firstTask),
        new Thread(secondTask),
    };
    for (Thread worker : workers) {
        worker.start();
    }
}
}
我不明白哪里出了问题。PS:我刚开始接触并行(parallel)编程。