DeepLearning4J –进行预测后更新训练有素的模型

时间:2019-05-13 11:36:01

标签: java neural-network deeplearning4j

我用DeepLearning4J建立了两个神经网络,一个前馈和一个递归网络。为了训练,测试和评估这两种网络,我使用以下代码段:

protected Evaluation trainAndTestNetwork(MultiLayerNetwork model, DataSetIterator trainData, DataSetIterator testData) {
    String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
    Evaluation evaluation = new Evaluation();
    for (int i = 0; i < nEpochs; i++) {
        model.fit(trainData);
        evaluation = model.evaluate(testData); // Evaluate on test set
        log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));
        testData.reset();
        trainData.reset();
    }
    return evaluation;
}

要在光盘上存储最终网络,我使用了以下方法:

public static void saveNeuralNet(String networkPath, Model model, DataNormalization normalizer) {
    try {
        File file = new File(networkPath + MODEL_NAME + ".zip"); // MODEL_NAME id model
        if (file.exists()) file.delete();
        file.createNewFile();
        ModelSerializer.writeModel(model, file, true, normalizer);
    } catch (IOException e) { e.printStackTrace(); }
}

使用以下方法还原模型后:

public static MultiLayerNetwork restoreMultiLayerNetwork(String networkDirectory) throws IOException {
    File file = new File(networkDirectory + "/" + MODEL_NAME + ".zip"); // MODEL_NAME id model
    return ModelSerializer.restoreMultiLayerNetwork(file, true);
}

我想对一个从未见过的数据集做出预测。因此,我还原了模型,得到了DataSetIterator并将其传递给模型的output方法。

MultiLayerNetwork model = Classifier.restoreMultiLayerNetwork(configFile.getNetworkPath() + configFile.getNetworkName());
dataSetIterator = getDataSetIterator(configFile, predictionData); // Method follows below
INDArray networkPrediction = model.output(dataSetIterator);

要将dataSetIterator传递到模型的output方法中,我必须将数据加载到CSVRecordReader中以进行前馈:

public static DataSetIterator getDataSetIterator(ConfigFile configFile, Double[][] predictionData) {
    try {
        DataFileWriter.writeInputLayerToCSVFile(Predictor.tempFilepath, new Double[][] { Aggregator.mergeAggregation(predictionData) });
        CSVRecordReader dataSet = new CSVRecordReader();
        dataSet.initialize(new FileSplit(new File(DataFileWriter.getRootDirectory() + Predictor.tempFilepath + DataFileWriter.FILE_FORMAT)));
        return new RecordReaderDataSetIterator(dataSet, configFile.getBatchSize());
    } catch (IOException | InterruptedException e) { e.printStackTrace(); }
    return null;
}

或进入SequenceRecordReader以获取循环网络:

public static DataSetIterator getDataSetIterator(ConfigFile configFile, Double[][] predictionData) {
    try {

        DataFileWriter.writeInputLayerToCSVFile(Predictor.tempFilepath + "/0", predictionData);
        SequenceRecordReader dataSetFeatures = new CSVSequenceRecordReader();
        dataSetFeatures.initialize(new NumberedFileInputSplit(DataFileWriter.getRootDirectory() + Predictor.tempFilepath + "/%d.csv", 0, 0));
        return new SequenceRecordReaderDataSetIterator(dataSetFeatures, configFile.getBatchSize(), configFile.getNumLabelClasses(), 0);
    } catch (IOException | InterruptedException e) { e.printStackTrace(); }
    return null;
}

我有四个问题:

  1. 第一个问题不是关于实际问题,而是关于使用测试集的经过训练的网络的测试和评估过程的一般问题。测试数据是自动存储在模型中还是必须使用测试数据集更新模型?

  2. 当我使用INDArray networkPrediction = model.output(dataSetIterator);时,如何更新模型?当我得到DataSetIterator时,我显然没有设置任何输出标签,但是当结果最终出现在现实世界中时,我必须更新当前模型。我只是做与getDataSetIterator(ConfigFile configFile, Double[][] predictionData)之前的操作一样,但是这次将输出写入数据集,并设置标签索引和可能的结果数量,例如new RecordReaderDataSetIterator(trainFeatures, batchSize, 0, numLabelClasses);吗?还是我必须从头开始运行整个网络,重新培训并使用添加的新数据再次测试所有内容?

  3. 第三个问题是关于在进行预测时如何使用SequenceRecordReaderDataSetIterator。尽管一切都适用于前馈的方法getDataSetIterator(ConfigFile configFile, Double[][] predictionData),但我仍在尝试使用循环网络的相同方法。 RecordReaderDataSetIterator的构造函数仅允许传递输入数据集和批处理大小,如果您要进行预测,这很有意义。 SequenceRecordReaderDataSetIterator的构造函数似乎需要可能的输出标签数量以及标签索引。这没有什么意义,因为还没有输出,这就是为什么我总是得到表示无效输出的异常Exception in thread "main" java.lang.NumberFormatException: For input string: "0.0"的原因。是否有另一种方法来获取正确的DataSetIterator以便使用循环网络进行预测,还是我会错过其他事情?

  4. 由于我还尝试了其他方法来从网络中获取预测,因此偶然发现了模型的方法predict(DataSet dataset)。所以最后一个问题是,List<String> networkPrediction = model.predict(dataSet);INDArray networkPrediction = model.output(dataSetIterator);相比有什么区别?

0 个答案:

没有答案