Output of a trained regression neural network is the same for all entries

Time: 2019-04-26 12:09:15

Tags: deeplearning4j

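I have trained and tested a neural network for regression. When I run the network on new data, I get the same number for every entry. I adapted it from a classification system, which worked.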

Here is the code:

private void writeINDArray(INDArray output, PrintWriter writer, Iterator<String> identifierIterator) {
    int rows = output.rows();
    int columns = output.columns();

    // Write one tab-separated line per output row, followed by its identifier.
    for (int i = 0; i < rows; i++) {
        INDArray row = output.getRow(i);
        StringJoiner stringJoiner = new StringJoiner("\t");

        for (int j = 0; j < columns; j++) {
            stringJoiner.add(Float.toString(row.getFloat(j)));
        }

        // Identifiers are consumed in the same order as the input rows.
        if (identifierIterator.hasNext()) {
            stringJoiner.add(identifierIterator.next());
        }
        else {
            throw new RuntimeException("identifier list is empty!");
        }
        writer.println(stringJoiner.toString());
        log.info(stringJoiner.toString());
    }
}

@Override
public void run(File neuralNetworkZipFile, File fingerPrintFile, List<String> identifiers) {
    log.info(String.format("running %s on %s", neuralNetworkZipFile.getAbsolutePath(), fingerPrintFile.getAbsolutePath()));

    Iterator<String> identifierIterator = identifiers.iterator();

    runResultFile = new File("run_results_" + Utility.timeDate() + ".txt");

    try (RecordReader recordReader = new CSVRecordReader(0, ','); PrintWriter writer = new PrintWriter(runResultFile)) {
        recordReader.initialize(new FileSplit(fingerPrintFile));

        // Iterator over the fingerprint CSV, plus the trained model restored from its zip file.
        DataSetIterator iterator = neuralNetworkSupporter.getDataSetIterator(recordReader);
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(neuralNetworkZipFile);

        while (iterator.hasNext()) {
            DataSet fingerPrint = iterator.next();
            // Inference only: train = false disables training-time behaviour such as dropout.
            INDArray output = model.output(fingerPrint.getFeatures(), false);

            writeINDArray(output, writer, identifierIterator);
        }
    }
    catch (IOException | InterruptedException e) {
        e.printStackTrace();
    }
}
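A useful first step is to check whether the identical numbers come from identical inputs or from the network itself. The following is a minimal sketch, not part of the original code: it assumes a batch has already been copied into a plain `float[][]` (for example with the same `getRow`/`getFloat` loop that `writeINDArray` uses), and the class name `OutputCheck`, the method `allRowsIdentical` and the tolerance value are all hypothetical.

```java
public class OutputCheck {
    /**
     * Returns true if every row of m matches the first row element by element,
     * within an absolute tolerance tol. Rows of different length count as different.
     */
    static boolean allRowsIdentical(float[][] m, float tol) {
        if (m.length < 2) {
            return true; // zero or one row is trivially "all identical"
        }
        float[] first = m[0];
        for (int i = 1; i < m.length; i++) {
            if (m[i].length != first.length) {
                return false;
            }
            for (int j = 0; j < first.length; j++) {
                // Written as !(x <= tol) so that a NaN difference counts as "different".
                if (!(Math.abs(m[i][j] - first[j]) <= tol)) {
                    return false;
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        float[][] collapsed = {{0.42f}, {0.42f}, {0.42f}};
        float[][] varied = {{0.42f}, {0.57f}, {0.13f}};
        System.out.println(allRowsIdentical(collapsed, 1e-6f)); // true
        System.out.println(allRowsIdentical(varied, 1e-6f));    // false
    }
}
```

Running the check on both `fingerPrint.getFeatures()` and `output` separates the two cases: identical feature rows point at how the input file is written or read, while distinct feature rows with identical outputs point at the trained model or at how its input is prepared.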

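Any suggestions as to what I am doing wrong? I have read the JavaDoc for MultiLayerNetwork and INDArray, but nothing there seems to explain this problem. I did have some trouble loading data that has no regression values, and had to make an ugly modification to get it working: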

private void outputBitSet(MolecularProperties molecularProperties, PrintWriter writer) {
    StringBuilder builder = new StringBuilder();
    BitSet fingerprintBitSet = molecularProperties.bitSet;

    if (useStructuralFingerprint) {
        for (int i = 0; i < fingerprintBitSet.size(); i++) {
            double bit = fingerprintBitSet.get(i) ? VALUE2 : VALUE1;

            appendComma(builder);
            builder.append(bit);
        }
    }
    if (useMolecularProperties) {
        addProperties(builder, molecularProperties);
    }

    if (! action.equals(Action.RUN)) {
        if (isRegression) {
            log.debug(String.format("%8.6f", molecularProperties.regressionValue) + " " + molecularProperties.id);

            appendComma(builder);
            builder.append(String.format("%8.6f", molecularProperties.regressionValue));
        }
        else {
            appendComma(builder);
            builder.append(molecularProperties.classification);
        }
    }
    else { // TODO This is needed to fix an issue with CSVRecordReader expecting there to be one or more regression values.
        if (isRegression) {
            appendComma(builder);
            builder.append(String.format("%8.6f", VALUE1));
        }
        else {
            appendComma(builder);
            builder.append(CLASS1);
        }
    }
    writer.println(builder.toString());
}

private void outputBitSet(List<MolecularProperties> molecularPropertiesList, PrintWriter writer) {
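    // Shuffle only when training; at RUN time the rows must keep their order
    // so that the output rows can be paired with the identifiers in sequence.
    if (!action.equals(Action.RUN)) {
        Collections.shuffle(molecularPropertiesList);
    }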
    molecularPropertiesList.forEach(m -> outputBitSet(m, writer));
}
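The modification amounts to writing the same column layout at RUN time as during training: the fingerprint columns followed by one label column, with a placeholder label when no real value exists. The sketch below shows only that layout for the regression case. It assumes `VALUE1 = 0.0` and `VALUE2 = 1.0` (the real constants are not shown in the question), leaves out the `addProperties` columns, and uses the hypothetical names `FingerprintRow` and `buildRow`.

```java
import java.util.BitSet;
import java.util.Locale;
import java.util.StringJoiner;

public class FingerprintRow {
    // Hypothetical stand-ins for the question's VALUE1 / VALUE2 constants.
    static final double VALUE1 = 0.0;
    static final double VALUE2 = 1.0;

    /**
     * Builds one CSV line: nBits fingerprint columns followed by one label column.
     * A null label (the RUN case) is written as VALUE1, so every line has the
     * same number of columns as the training file.
     */
    static String buildRow(BitSet bits, int nBits, Double label) {
        StringJoiner joiner = new StringJoiner(",");
        for (int i = 0; i < nBits; i++) {
            joiner.add(Double.toString(bits.get(i) ? VALUE2 : VALUE1));
        }
        double labelValue = (label != null) ? label : VALUE1;
        // Locale.ROOT keeps '.' as the decimal separator whatever the JVM default is.
        joiner.add(String.format(Locale.ROOT, "%8.6f", labelValue));
        return joiner.toString();
    }

    public static void main(String[] args) {
        BitSet bits = new BitSet();
        bits.set(0);
        bits.set(2);
        System.out.println(buildRow(bits, 4, 3.25)); // 1.0,0.0,1.0,0.0,3.250000
        System.out.println(buildRow(bits, 4, null)); // 1.0,0.0,1.0,0.0,0.000000
    }
}
```

Two details differ from the original `outputBitSet` on purpose. `String.format` without a locale uses the JVM default, and in locales that use a decimal comma the label would be split across two CSV columns; passing `Locale.ROOT` avoids that. The fingerprint length is passed explicitly because `BitSet.size()` reports the space allocated for the set rather than a fixed fingerprint length.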

I would really appreciate any suggestions.

Bart

0 answers