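I have trained and tested a neural network for regression. When I run the network on new data, every entry gets the same number. I adapted it from a classification system, which did work.
Here is the code: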
private void writeINDArray(INDArray output, PrintWriter writer, Iterator<String> identifierIterator) {
    int rows = output.rows();
    int columns = output.columns();
    for (int i = 0; i < rows; i++) {
        INDArray row = output.getRow(i);
        // One tab-separated line per row: all output values, then the matching identifier
        StringJoiner stringJoiner = new StringJoiner("\t");
        for (int j = 0; j < columns; j++) {
            stringJoiner.add(Float.toString(row.getFloat(j)));
        }
        if (identifierIterator.hasNext()) {
            stringJoiner.add(identifierIterator.next());
        }
        else {
            throw new RuntimeException("identifier list is empty!");
        }
        String line = stringJoiner.toString();
        writer.println(line);
        log.info(line);
    }
}
@Override
public void run(File neuralNetworkZipFile, File fingerPrintFile, List<String> identifiers) {
    log.info(String.format("running %s on %s", neuralNetworkZipFile.getAbsolutePath(), fingerPrintFile.getAbsolutePath()));
    Iterator<String> identifierIterator = identifiers.iterator();
    runResultFile = new File("run_results_" + Utility.timeDate() + ".txt");
    try (RecordReader recordReader = new CSVRecordReader(0, ','); PrintWriter writer = new PrintWriter(runResultFile)) {
        recordReader.initialize(new FileSplit(fingerPrintFile));
        DataSetIterator iterator = neuralNetworkSupporter.getDataSetIterator(recordReader);
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(neuralNetworkZipFile);
        while (iterator.hasNext()) {
            DataSet fingerPrint = iterator.next();
            INDArray output = model.output(fingerPrint.getFeatures(), false);
            writeINDArray(output, writer, identifierIterator);
        }
    }
    catch (IOException | InterruptedException e) {
        e.printStackTrace();
    }
}
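Since writeINDArray writes one tab-separated line per output row (prediction values first, identifier last), the symptom can be checked directly on a run_results_*.txt file. The following is a minimal, stdlib-only sketch; the class name PredictionSpread and the sample lines are hypothetical, and it assumes the prediction of interest is in column 0.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

final class PredictionSpread {

    /**
     * Returns max - min of one numeric column in tab-separated result lines
     * (prediction values first, identifier last, as writeINDArray writes them).
     * A spread of (almost) zero means every entry got the same prediction.
     */
    static double spread(List<String> lines, int column) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        int count = 0;
        for (String line : lines) {
            if (line.trim().isEmpty()) {
                continue; // ignore blank lines
            }
            double value = Double.parseDouble(line.split("\t")[column]);
            min = Math.min(min, value);
            max = Math.max(max, value);
            count++;
        }
        return count == 0 ? 0.0 : max - min;
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical sample lines; pass a run_results_*.txt path to check a real file instead.
        List<String> lines = args.length > 0
                ? Files.readAllLines(Paths.get(args[0]))
                : Arrays.asList("0.4213\tmol-1", "0.4213\tmol-2", "0.4213\tmol-3");
        System.out.println("spread of column 0: " + spread(lines, 0)); // prints 0.0 for the sample lines
    }
}
```

Run it with the path of a results file as the only argument; a spread at or near 0.0 confirms that every entry received the same prediction.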
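Any suggestions as to what I am doing wrong? I have read the Javadoc for MultiLayerNetwork and INDArray, but nothing there seems to explain this problem. I did have some trouble loading data that has no label values, and had to make an ugly modification to get it working: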
private void outputBitSet(MolecularProperties molecularProperties, PrintWriter writer) {
    StringBuilder builder = new StringBuilder();
    BitSet fingerprintBitSet = molecularProperties.bitSet;
    if (useStructuralFingerprint) {
        for (int i = 0; i < fingerprintBitSet.size(); i++) {
            double bit = fingerprintBitSet.get(i) ? VALUE2 : VALUE1;
            appendComma(builder);
            builder.append(bit);
        }
    }
    if (useMolecularProperties) {
        addProperties(builder, molecularProperties);
    }
    if (! action.equals(Action.RUN)) {
        if (isRegression) {
            log.debug(String.format("%8.6f", molecularProperties.regressionValue) + " " + molecularProperties.id);
            appendComma(builder);
            builder.append(String.format("%8.6f", molecularProperties.regressionValue));
        }
        else {
            appendComma(builder);
            builder.append(molecularProperties.classification);
        }
    }
    else { // TODO This is needed to work around an issue with CSVRecordReader expecting there to be one or more regression values.
        if (isRegression) {
            appendComma(builder);
            builder.append(String.format("%8.6f", VALUE1));
        }
        else {
            appendComma(builder);
            builder.append(CLASS1);
        }
    }
    writer.println(builder.toString());
}
private void outputBitSet(List<MolecularProperties> molecularPropertiesList, PrintWriter writer) {
    if (!action.equals(Action.RUN)) {
        Collections.shuffle(molecularPropertiesList);
    }
    molecularPropertiesList.forEach(m -> outputBitSet(m, writer));
}
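Two details of outputBitSet can be checked in isolation, although neither is confirmed as the cause of the constant predictions. First, BitSet.size() returns the allocated capacity (a multiple of 64 in OpenJDK), not the fingerprint length, so the loop may write extra always-zero columns, or different column counts for BitSets allocated differently. Second, String.format("%8.6f", value) uses the default locale, so in a locale with a comma decimal separator the label is written as something like 0,500000, which splits into two fields in a comma-separated file. The sketch below is a hypothetical helper (FingerprintCsvRow.toCsvRow) that takes an explicit width and formats with Locale.ROOT; 0.0 and 1.0 stand in for VALUE1 and VALUE2, whose actual values the post does not show.

```java
import java.util.BitSet;
import java.util.Locale;
import java.util.StringJoiner;

final class FingerprintCsvRow {

    // Assumed encoding: 0.0 for a cleared bit and 1.0 for a set bit
    // (stand-ins for VALUE1 and VALUE2, whose real values are not shown).
    private static final String CLEARED = "0.0";
    private static final String SET = "1.0";

    /**
     * Builds one comma-separated row: exactly `width` fingerprint columns followed by a label.
     * The explicit width replaces BitSet.size(), which reports allocated capacity.
     * Locale.ROOT keeps '.' as the decimal separator regardless of the default locale.
     */
    static String toCsvRow(BitSet bits, int width, double label) {
        StringJoiner joiner = new StringJoiner(",");
        for (int i = 0; i < width; i++) {
            joiner.add(bits.get(i) ? SET : CLEARED);
        }
        joiner.add(String.format(Locale.ROOT, "%.6f", label));
        return joiner.toString();
    }

    public static void main(String[] args) {
        BitSet bits = new BitSet(8); // hypothetical 8-bit fingerprint
        bits.set(1);
        bits.set(5);
        System.out.println(bits.size()); // allocated capacity, not 8 (64 on OpenJDK)
        System.out.println(toCsvRow(bits, 8, 0.5)); // prints 0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.500000
        System.out.println(String.format(Locale.GERMANY, "%.6f", 0.5)); // prints 0,500000
    }
}
```

The width 8 from "%8.6f" is dropped here because it only pads short values with leading spaces; whichever format is used, training and run files need the same column count.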
I would really appreciate any suggestions.
Bart