I am trying to build a neural network with the Deeplearning4j framework, and I am getting the following error:
java.lang.IllegalStateException: Unexpected state occurred for AsyncDataSetIterator: runnable died or no data available
Here is my code:
package com.neuralnetwork;
import com.sliit.preprocessing.NormalizeDataset;
import com.sliit.ruleengine.RuleEngine;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
import weka.core.converters.CSVSaver;
import weka.filters.Filter;
import weka.filters.supervised.instance.StratifiedRemoveFolds;
import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * Deep Belief Neural Network to detect frauds
 * (implemented here as a stack of GravesLSTM layers with an RNN output layer).
 */
public class FraudDetectorDeepBeilefNN {
private static final Log log = LogFactory.getLog(FraudDetectorDeepBeilefNN.class);
public int outputNum = 4;
private int iterations = 5;
private int seed = 1234;
private MultiLayerNetwork model = null;
public int HIDDEN_LAYER_COUNT = 8;
public int numHiddenNodes = 21;
public int inputs = 41;
private String uploadDirectory = "D:/Data";
private ArrayList<Map<String, Double>> roc;
public FraudDetectorDeepBeilefNN() {
}
public void buildModel() {
System.out.println("Build model....");
iterations = outputNum + 1;
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
builder.iterations(iterations);
builder.learningRate(0.001);
// builder.momentum(0.01);
builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
builder.seed(seed);
builder.biasInit(1);
builder.regularization(true).l2(1e-5);
builder.updater(Updater.RMSPROP);
builder.weightInit(WeightInit.XAVIER);
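// Hidden layers: a stack of GravesLSTM layers; the first takes the raw input features,
// subsequent layers take the previous layer's numHiddenNodes activations.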
NeuralNetConfiguration.ListBuilder list = builder.list();
for (int i = 0; i < HIDDEN_LAYER_COUNT; i++) {
GravesLSTM.Builder hiddenLayerBuilder = new GravesLSTM.Builder();
hiddenLayerBuilder.nIn(i == 0 ? inputs : numHiddenNodes);
hiddenLayerBuilder.nOut(numHiddenNodes);
hiddenLayerBuilder.activation("tanh");
list.layer(i, hiddenLayerBuilder.build());
}
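// Output layer: softmax over outputNum classes with multi-class cross-entropy (MCXENT) loss.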
RnnOutputLayer.Builder outputLayer = new RnnOutputLayer.Builder(LossFunction.MCXENT);
outputLayer.activation("softmax");
outputLayer.nIn(numHiddenNodes);
outputLayer.nOut(outputNum);
list.layer(HIDDEN_LAYER_COUNT, outputLayer.build());
list.pretrain(false);
list.backprop(true);
MultiLayerConfiguration configuration = list.build();
model = new MultiLayerNetwork(configuration);
model.init();
//model.setListeners(new ScoreIterationListener(1));
}
public String trainModel(String modelName, String filePath, int outputs, int inputsTot) throws NeuralException {
try {
System.out.println("Neural Network Training start");
loadSaveNN(modelName, false);
if (model == null) {
buildModel();
}
System.out.println("modal" + model);
System.out.println("file path " + filePath);
File fileGeneral = new File(filePath);
CSVLoader loader = new CSVLoader();
loader.setSource(fileGeneral);
Instances instances = loader.getDataSet();
instances.setClassIndex(instances.numAttributes() - 1);
StratifiedRemoveFolds stratified = new StratifiedRemoveFolds();
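// StratifiedRemoveFolds options: -N 5 = split into 5 stratified folds,
// -F 1 = select fold 1, -S 1 = random-number seed.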
String[] options = new String[6];
options[0] = "-N";
options[1] = Integer.toString(5);
options[2] = "-F";
options[3] = Integer.toString(1);
options[4] = "-S";
options[5] = Integer.toString(1);
stratified.setOptions(options);
stratified.setInputFormat(instances);
stratified.setInvertSelection(false);
Instances testInstances = Filter.useFilter(instances, stratified);
stratified.setInvertSelection(true);
Instances trainInstances = Filter.useFilter(instances, stratified);
String directory = fileGeneral.getParent();
CSVSaver saver = new CSVSaver();
File trainFile = new File(directory + "/" + "normtrainadded.csv");
File testFile = new File(directory + "/" + "normtestadded.csv");
if (trainFile.exists()) {
trainFile.delete();
}
trainFile.createNewFile();
if (testFile.exists()) {
testFile.delete();
}
testFile.createNewFile();
saver.setFile(trainFile);
saver.setInstances(trainInstances);
saver.writeBatch();
saver = new CSVSaver();
saver.setFile(testFile);
saver.setInstances(testInstances);
saver.writeBatch();
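// NOTE: CSVSequenceRecordReader(skipNumLines=0, delimiter=",") treats each file in the
// FileSplit as ONE whole sequence, so every row of the CSV becomes a time step of a
// single sequence rather than an independent example.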
SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
recordReader.initialize(new org.datavec.api.split.FileSplit(trainFile));
SequenceRecordReader testReader = new CSVSequenceRecordReader(0, ",");
testReader.initialize(new org.datavec.api.split.FileSplit(testFile));
DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(recordReader, 2, outputs, inputsTot, false);
DataSetIterator testIterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(testReader, 2, outputs, inputsTot, false);
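// Constructor arguments: (reader, miniBatchSize=2, numPossibleLabels=outputs,
// labelIndex=inputsTot, regression=false). inputsTot must be the zero-based index of
// the label column; with the label last this equals the number of feature columns.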
roc = new ArrayList<Map<String, Double>>();
String statMsg = "";
Evaluation evaluation;
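// Alternates the roles of the two folds: on odd passes the model is fit on the
// fold that was evaluated on the previous pass.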
for (int i = 0; i < 100; i++) {
if (i % 2 == 0) {
model.fit(iterator);
evaluation = model.evaluate(testIterator);
} else {
model.fit(testIterator);
evaluation = model.evaluate(iterator);
}
Map<String, Double> map = new HashMap<String, Double>();
Map<Integer, Integer> falsePositives = evaluation.falsePositives();
Map<Integer, Integer> trueNegatives = evaluation.trueNegatives();
Map<Integer, Integer> truePositives = evaluation.truePositives();
Map<Integer, Integer> falseNegatives = evaluation.falseNegatives();
// Cast to double first: these counts are Integers, so plain division is integer
// division and would silently truncate both rates to 0 or 1.
double fpr = (double) falsePositives.get(1) / (falsePositives.get(1) + trueNegatives.get(1));
double tpr = (double) truePositives.get(1) / (truePositives.get(1) + falseNegatives.get(1));
map.put("FPR", fpr);
map.put("TPR", tpr);
roc.add(map);
statMsg = evaluation.stats();
iterator.reset();
testIterator.reset();
}
loadSaveNN(modelName, true);
System.out.println("ROC " + roc);
return statMsg;
} catch (Exception e) {
e.printStackTrace();
System.out.println("Error ocuured while building neural netowrk :" + e.getMessage());
throw new NeuralException(e.getLocalizedMessage(), e);
}
}
public boolean generateModel(String modelName) {
boolean status = false;
try {
loadSaveNN(modelName, true);
status = true;
} catch (Exception e) {
System.out.println("Error occurred:" + e.getLocalizedMessage());
}
return status;
}
private void loadSaveNN(String name, boolean save) {
File directory = new File(uploadDirectory);
File[] allNN = directory.listFiles();
boolean status = false;
try {
if (model == null && save) {
buildModel();
}
if (allNN != null && allNN.length > 0) {
for (File NN: allNN) {
String fnme = FilenameUtils.removeExtension(NN.getName());
if (name.equals(fnme)) {
status = true;
if (save) {
ModelSerializer.writeModel(model, NN, true);
System.out.println("Model Saved With Weights Successfully");
} else {
model = ModelSerializer.restoreMultiLayerNetwork(NN);
}
break;
}
}
}
if (!status && save) {
//File tempFIle = File.createTempFile(name,".zip",directory);
File tempFile = new File(directory.getAbsolutePath() + "/" + name + ".zip");
if (!tempFile.exists()) {
tempFile.createNewFile();
}
ModelSerializer.writeModel(model, tempFile, true);
}
} catch (IOException e) {
System.out.println("Error occurred:" + e.getMessage());
}
}
public String testModel(String modelName, String[] rawData, Map<Integer, String> map, int inputs, int outputs, String ruleModelSavePath) throws Exception {
String status = "";
String fpath = uploadDirectory;
// uploadDirectory ("D:/Data") has no trailing separator, so add "/" explicitly
// when concatenating paths below; otherwise the path becomes "D:/Dataoriginal/...".
FileWriter fwriter = new FileWriter(uploadDirectory + "/original/insertdata.csv", true);
fwriter.write("\n");
fwriter.write(rawData[0]);
fwriter.close();
if (model == null) {
loadSaveNN(modelName, false);
}
NormalizeDataset norm = new NormalizeDataset(uploadDirectory + "/original/insertdata.csv");
norm.updateStringValues(map);
norm.whiteningData();
norm.normalizeDataset();
// Assumes NormalizeDataset writes the normalized copy as "norm" + <name> beside the original.
BufferedReader bufferedReader = new BufferedReader(new FileReader(new File(uploadDirectory + "/original/norminsertdata.csv")));
String output = "";
String prevOutput = "";
while ((output = bufferedReader.readLine()) != null) {
prevOutput = output;
}
bufferedReader.close();
File readFile = new File(uploadDirectory + "/normtest.csv");
if (readFile.exists()) {
readFile.delete();
}
readFile.createNewFile();
PrintWriter writer = new PrintWriter(readFile);
writer.println(prevOutput);
writer.flush();
writer.close();
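// Re-read the single normalized row as a one-step "sequence" for prediction.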
SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
recordReader.initialize(new org.datavec.api.split.FileSplit(new File(uploadDirectory + "/normtest.csv")));
DataSetIterator iterator = new org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator(recordReader, 2, outputs, inputs, false);
INDArray outputArr = null;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
INDArray provided = ds.getFeatureMatrix();
outputArr = model.rnnTimeStep(provided);
}
//INDArray finalOutput = Nd4j.argMax(outputArr,1);
// Read the predicted class index directly rather than parsing INDArray.toString(),
// which may include brackets and fail in Double.parseDouble().
int result = Nd4j.argMax(outputArr, 1).getInt(0);
if (result == 0) {
status = "Normal Transaction";
} else {
status = "Fraud Transaction, ";
bufferedReader = new BufferedReader(new FileReader(new File(uploadDirectory + "/original/insertdata.csv")));
String heading = "";
heading = bufferedReader.readLine();
bufferedReader.close();
File ruleFile = new File(uploadDirectory + "/normrules.csv");
if (ruleFile.exists()) {
ruleFile.delete();
}
ruleFile.createNewFile();
PrintWriter writeNew = new PrintWriter(ruleFile);
writeNew.println(heading);
writeNew.println(rawData[0]);
writeNew.flush();
writeNew.close();
RuleEngine engine = new RuleEngine(fpath + "/original/insertdata.csv");
engine.geneateModel(ruleModelSavePath, false);
String finalStatus = status + "Attack Type: " + engine.predictionResult(uploadDirectory + "/normrules.csv");
status = finalStatus;
}
return status;
}
public void setUploadDirectory(String uploadDirectory) {
this.uploadDirectory = uploadDirectory;
}
public static void main(String[] args) {
FraudDetectorDeepBeilefNN neuralNetwork = new FraudDetectorDeepBeilefNN();
System.out.println("start=======================");
try {
neuralNetwork.inputs = 4;
neuralNetwork.numHiddenNodes = 3;
neuralNetwork.HIDDEN_LAYER_COUNT = 2;
neuralNetwork.outputNum = 2;
neuralNetwork.buildModel();
String output = neuralNetwork.trainModel("nn", "D:/Data/a.csv", 2, 4);
System.out.println(output);
} catch (Exception e) {
e.printStackTrace();
}
}
public ArrayList<Map<String, Double>> getRoc() {
return roc;
}
}
Here are the first few lines of my dataset:
length,caplen,hlen,version,output
60,60,6,0,normal
243,243,8,0,normal
60,60,6,0,neptune
60,60,6,0,neptune
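For comparison, here is a minimal sketch of how rows like these are typically loaded in that 0.x-era DL4J/DataVec API when each row is an independent example rather than one time step of a sequence. The file name D:/Data/a_numeric.csv is hypothetical, and the sketch assumes the output column has already been mapped to integer class indices (e.g. normal -> 0, neptune -> 1); it is not taken from the code above:

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;

public class CsvIteratorSketch {
    public static void main(String[] args) throws Exception {
        // Skip the single header line; fields are comma-separated.
        RecordReader reader = new CSVRecordReader(1, ",");
        // Hypothetical file: a copy of the dataset with numeric class labels.
        reader.initialize(new FileSplit(new File("D:/Data/a_numeric.csv")));
        int labelIndex = 4; // "output" is the fifth column (zero-based index 4)
        int numClasses = 2; // e.g. normal vs. neptune
        int batchSize = 32;
        // Each CSV row becomes one independent example (no time dimension).
        DataSetIterator iterator = new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numClasses);
        while (iterator.hasNext()) {
            System.out.println(iterator.next());
        }
    }
}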