我想训练一个简单的LSTM网络,但是我遇到了异常。
mykey: {name: "stack"}
我正在训练一个具有单个LSTM单元和单个输出单元的简单NN进行回归。
我在csv文件中创建了一个仅包含10个具有可变序列长度(从5到10)的样本的训练数据集,每个样本仅由一个输入值和一个输出值组成。
我创建了一个SequenceRecordReaderDataSetIterator。
当我训练网络时,代码引发了以下异常:
java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1], stride: [1,10]]
我尝试直接用'f'顺序的INDArray构造随机数据集并创建数据集迭代器,这样代码运行无误。
所以问题似乎出在CSVSequenceRecordReader所创建的张量的形状(顺序)上。
有人有这个问题吗?
CSVSequenceRecordReader
package org.mmarini.lstmtest;
import java.io.IOException;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/**
*
*/
/**
 * Builds a {@link DataSetIterator} that reads time-series samples from a set of
 * numbered CSV files (one sequence per file).
 *
 * <p>The reader is configured once via the constructor and produces a fresh
 * iterator on each {@link #apply()} call.
 */
public class SingleFileTimeSeriesDataReader {
    private final int miniBatchSize;
    private final int numPossibleLabels;
    private final boolean regression;
    private final String filePattern;
    private final int maxFileIdx;
    private final int minFileIdx;
    private final int numInputs;

    /**
     * Creates the reader configuration.
     *
     * @param filePattern       file name template containing a {@code %d}
     *                          placeholder for the file index
     * @param minFileIdx        first file index (inclusive)
     * @param maxFileIdx        last file index (inclusive)
     * @param numInputs         number of input columns in each CSV row
     * @param numPossibleLabels number of label classes (ignored for regression)
     * @param miniBatchSize     number of sequences per minibatch
     * @param regression        true for regression targets, false for classification
     */
    public SingleFileTimeSeriesDataReader(final String filePattern, final int minFileIdx, final int maxFileIdx,
            final int numInputs, final int numPossibleLabels, final int miniBatchSize, final boolean regression) {
        this.filePattern = filePattern;
        this.minFileIdx = minFileIdx;
        this.maxFileIdx = maxFileIdx;
        this.numInputs = numInputs;
        this.numPossibleLabels = numPossibleLabels;
        this.miniBatchSize = miniBatchSize;
        this.regression = regression;
    }

    /**
     * Creates a new iterator over the configured CSV sequence files.
     *
     * @return a dataset iterator backed by a {@link CSVSequenceRecordReader}
     * @throws IOException          if a file cannot be read
     * @throws InterruptedException if initialization is interrupted
     */
    public DataSetIterator apply() throws IOException, InterruptedException {
        // No header lines to skip; fields are comma-separated.
        final SequenceRecordReader recordReader = new CSVSequenceRecordReader(0, ",");
        recordReader.initialize(new NumberedFileInputSplit(filePattern, minFileIdx, maxFileIdx));
        return new SequenceRecordReaderDataSetIterator(recordReader, miniBatchSize, numPossibleLabels,
                numInputs, regression);
    }
}
/**
*
*/
package org.mmarini.lstmtest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
/**
* @author mmarini
*
*/
/**
 * Builds a two-layer recurrent network configuration: a single LSTM layer
 * followed by an identity-activation RNN output layer trained with the
 * mean squared logarithmic error loss.
 *
 * @author mmarini
 */
public class TestConfBuilder {
    private final int noInputUnits;
    private final int noOutputUnits;
    private final int noLstmUnits;

    /**
     * Stores the layer sizes used by {@link #build()}.
     *
     * @param noInputUnits  number of network inputs
     * @param noOutputUnits number of network outputs
     * @param noLstmUnits   number of LSTM units in the hidden layer
     */
    public TestConfBuilder(final int noInputUnits, final int noOutputUnits, final int noLstmUnits) {
        super();
        this.noInputUnits = noInputUnits;
        this.noOutputUnits = noOutputUnits;
        this.noLstmUnits = noLstmUnits;
    }

    /**
     * Assembles the network configuration.
     *
     * @return the LSTM + RNN-output multilayer configuration
     */
    public MultiLayerConfiguration build() {
        // Hidden recurrent layer: tanh activation, sized by the constructor args.
        final LSTM hidden = new LSTM.Builder()
                .units(noLstmUnits)
                .nIn(noInputUnits)
                .activation(Activation.TANH)
                .build();
        // Output layer: identity activation for regression, MSLE loss.
        final RnnOutputLayer output = new RnnOutputLayer.Builder(LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
                .activation(Activation.IDENTITY)
                .nIn(noLstmUnits)
                .nOut(noOutputUnits)
                .build();
        return new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list(hidden, output)
                .build();
    }
}
我希望不会引发任何异常,但是我得到了以下异常:
package org.mmarini.lstmtest;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
/**
 * Trains the LSTM test network against the CSV sample files and verifies the
 * iterator exposes the expected column/label counts before fitting.
 */
class TestTrainingTest {
    private static final int MINI_BATCH_SIZE = 10;
    private static final int NUM_LABELS = 1;
    private static final boolean REGRESSION = true;
    private static final String SAMPLES_FILE = "src/test/resources/datatest/sample_%d.csv";
    private static final int MIN_INPUTS_FILE_IDX = 0;
    private static final int MAX_INPUTS_FILE_IDX = 9;
    private static final int NUM_INPUTS_COLUMN = 1;
    private static final int NUM_HIDDEN_UNITS = 1;

    /**
     * Builds a small in-memory dataset iterator: 2 samples, 1 feature/label
     * channel, 3 time steps each.
     *
     * @return iterator over a single minibatch of both samples
     */
    DataSetIterator createData() {
        // Features: [miniBatch=2][channels=1][timeSteps=3] — literal nesting
        // matches the declared shape.
        final double[][][] featuresAry = new double[][][] { { { 0.5, 0.2, 0.5 } }, { { 0.5, 1.0, 0.0 } } };
        final double[] featuresData = ArrayUtil.flattenDoubleArray(featuresAry);
        final int[] featuresShape = new int[] { 2, 1, 3 };
        // NOTE(review): 'c' order here vs the 'f' order that reportedly worked —
        // confirm which order the net expects before relying on this helper.
        final INDArray features = Nd4j.create(featuresData, featuresShape, 'c');
        // BUG FIX: the labels literal was nested [1][2][3] but reshaped to
        // [2,1,3], silently mixing label rows across samples. Nest it as
        // [miniBatch=2][channels=1][timeSteps=3] to match featuresAry.
        final double[][][] labelsAry = new double[][][] { { { 1.0, -1.0, 1.0 } }, { { 1.0, -1.0, -1.0 } } };
        final double[] labelsData = ArrayUtil.flattenDoubleArray(labelsAry);
        final int[] labelsShape = new int[] { 2, 1, 3 };
        final INDArray labels = Nd4j.create(labelsData, labelsShape, 'c');
        final INDArrayDataSetIterator iter = new INDArrayDataSetIterator(
                Arrays.asList(new Pair<INDArray, INDArray>(features, labels)), 2);
        System.out.println(iter.inputColumns());
        return iter;
    }

    /**
     * Resolves the sample-file template against the working directory.
     *
     * @param template relative file pattern
     * @return absolute path of the pattern
     */
    private String file(String template) {
        return new File(".", template).getAbsolutePath();
    }

    /**
     * End-to-end check: read the CSV sequences, build the network and fit one
     * pass over the data.
     *
     * @throws IOException          if the sample files cannot be read
     * @throws InterruptedException if reader initialization is interrupted
     */
    @Test
    void testBuild() throws IOException, InterruptedException {
        final SingleFileTimeSeriesDataReader reader = new SingleFileTimeSeriesDataReader(file(SAMPLES_FILE),
                MIN_INPUTS_FILE_IDX, MAX_INPUTS_FILE_IDX, NUM_INPUTS_COLUMN, NUM_LABELS, MINI_BATCH_SIZE, REGRESSION);
        final DataSetIterator data = reader.apply();
        assertThat(data.inputColumns(), equalTo(NUM_INPUTS_COLUMN));
        assertThat(data.totalOutcomes(), equalTo(NUM_LABELS));
        final TestConfBuilder builder = new TestConfBuilder(NUM_INPUTS_COLUMN, NUM_LABELS, NUM_HIDDEN_UNITS);
        final MultiLayerConfiguration conf = builder.build();
        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
        assertNotNull(net);
        net.init();
        net.fit(data);
    }
}