Deeplearning4j: Unable to get linear index >= 1

Time: 2016-08-30 18:13:00

Tags: java java-8 deeplearning4j

I am trying to train a neural network with deeplearning4j, but I get this error message, which I cannot interpret:

java.lang.reflect.InvocationTargetException
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.codehaus.mojo.exec.ExecJavaMojo$1.run(ExecJavaMojo.java:294)
    at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.IllegalArgumentException: Unable to get linear index >= 1
    at org.nd4j.linalg.api.ndarray.BaseNDArray.getDouble(BaseNDArray.java:3275)
    at org.deeplearning4j.eval.Evaluation.eval(Evaluation.java:197)
    at mypackage.myclass.main(Learn.java:77)

My data is in a CSV file. Each row holds 64 numbers (each 0, 1, 2, or 3) followed by a floating-point label between -1000 and 1000.

For example:

 2,3,2,2,1,1,2,3,0,1,1,2,3,1,1,0,0,0,2,2,0,0,3,1,0,1,3,1,1,1,2,2,2,2,2,2,3,2,2,2,2,3,3,1,2,2,1,3,0,0,2,3,2,3,2,0,0,3,0,1,1,3,3,2,-228.0

I use this code to load the CSV file and train the network:

RecordReader recordReader = new CSVRecordReader(0, ",");
recordReader.initialize(new FileSplit(new File("data.csv")));

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, new DoubleWritableConverter(), 600000, 64, 64, true);

     DataSet allData = iterator.next();
     SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.9);

     DataSet trainingData = testAndTrain.getTrain();
     DataSet testData = testAndTrain.getTest();

     //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
     DataNormalization normalizer = new NormalizerStandardize();
     normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
     normalizer.transform(trainingData);     //Apply normalization to the training data
     normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
     long seed = 123;
     int inputNum = 64;
     int hiddenNum = 64;
     int outputNum = 1;
     int iterations = 1;

     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
         .seed(seed)
         .activation("tanh")
         .iterations(iterations)
         .weightInit(WeightInit.XAVIER)
         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
         .learningRate(0.1)
         .regularization(true).l2(1e-4)
         .list()
         .layer(0, new DenseLayer.Builder().nIn(inputNum).nOut(hiddenNum).build())
         .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
             .activation("identity")
             .nIn(hiddenNum).nOut(outputNum).build())
         .backprop(true).pretrain(false)
         .build();
     MultiLayerNetwork model = new MultiLayerNetwork(conf);
     model.init();
     model.setListeners(new ScoreIterationListener(100));

     model.fit(trainingData); 

     //evaluate the model on the test set
     Evaluation eval = new Evaluation(2);
     INDArray output = model.output(testData.getFeatureMatrix());
     eval.eval(testData.getLabels(), output); // <---- this is line 77, where the error occurs
     System.out.println(eval.stats());
     recordReader.close();

What does this error mean, and how can I fix it?

1 answer:

Answer 0 (score: 1)

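You told the Evaluation to expect 2 labels (classes), and then called eval on a regression problem.

Evaluation is mainly meant for classification. For regression you have to do your own evaluation manually.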
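
As a rough illustration of what a manual regression evaluation could look like, here is a minimal, dependency-free sketch that computes mean squared error, mean absolute error, and root mean squared error over plain double arrays. The class name RegressionMetrics and its methods are hypothetical helpers, not part of the DL4J API. It assumes you have already copied the single label column and the single output column of the network into two arrays of equal length. The sample values in main are made up for demonstration.

```java
public class RegressionMetrics {

    /** Mean squared error: average of (label - prediction)^2. */
    static double mse(double[] labels, double[] predictions) {
        checkInputs(labels, predictions);
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }

    /** Mean absolute error: average of |label - prediction|. */
    static double mae(double[] labels, double[] predictions) {
        checkInputs(labels, predictions);
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        return sum / labels.length;
    }

    /** Root mean squared error, in the same units as the labels. */
    static double rmse(double[] labels, double[] predictions) {
        return Math.sqrt(mse(labels, predictions));
    }

    /** Rejects empty or mismatched inputs instead of returning NaN. */
    private static void checkInputs(double[] labels, double[] predictions) {
        if (labels.length == 0 || labels.length != predictions.length) {
            throw new IllegalArgumentException(
                "labels and predictions must be non-empty and of equal length");
        }
    }

    public static void main(String[] args) {
        // Made-up values for illustration only: one label and one prediction per example.
        double[] labels = {-228.0, 10.0, 512.5};
        double[] predictions = {-200.0, 0.0, 500.0};

        System.out.printf("MSE=%.3f MAE=%.3f RMSE=%.3f%n",
            mse(labels, predictions),
            mae(labels, predictions),
            rmse(labels, predictions));
    }
}
```

Unlike a classification evaluation, nothing here indexes per-class columns, so a network with a single output value per example is handled naturally. RMSE is often the easiest of the three to read because it is expressed in the same units as the labels.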