Is CSVSequenceRecordReader creating a dataset compatible with training an LSTM network?

Asked: 2019-01-22 11:56:13

Tags: deeplearning4j

I want to train a simple LSTM network, but I ran into an exception.


I am training a simple NN for regression, with a single LSTM cell and a single output unit.

I created a training dataset of CSV files containing just 10 samples with variable sequence lengths (from 5 to 10), each sample composed of just one input value and one output value.
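A hypothetical 5-step sample in this layout (made-up numbers, shown only to illustrate the format: one time step per row, the input value in column 0 and the target in column 1, matching the label column index of 1 passed to the iterator below):

0.10,0.30
0.40,-0.20
0.90,0.50
0.20,-0.70
0.60,0.10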

I created a SequenceRecordReaderDataSetIterator from the CSVSequenceRecordReader. When I train the network, the code throws this exception:

java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1], stride: [1,10]]

I tried generating a random dataset by coding the DataSetIterator directly with 'f'-ordered INDArrays, and the code runs without errors.

So the problem seems to be the shape of the tensors created by CSVSequenceRecordReader.
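One way to make that concrete (a diagnostic sketch of my own, not from the original post; it reuses the reader class and the same constants as the test further below, inside a method that declares throws IOException, InterruptedException) is to print the shape and memory order of the first batch:

final DataSetIterator iter = new SingleFileTimeSeriesDataReader(
        "src/test/resources/datatest/sample_%d.csv", 0, 9, 1, 1, 10, true).apply();
final org.nd4j.linalg.dataset.DataSet ds = iter.next();
// RNN features come out as [miniBatchSize, numInputColumns, timeSteps]
System.out.println(java.util.Arrays.toString(ds.getFeatures().shape()));
// 'c' or 'f' memory order of the feature and label arrays
System.out.println(ds.getFeatures().ordering());
System.out.println(ds.getLabels().ordering());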

Has anyone had this problem?

SingleFileTimeSeriesDataReader.java

package org.mmarini.lstmtest;

import java.io.IOException;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 *
 */
public class SingleFileTimeSeriesDataReader {

    private final int miniBatchSize;
    private final int numPossibleLabels;
    private final boolean regression;
    private final String filePattern;
    private final int maxFileIdx;
    private final int minFileIdx;
    private final int numInputs;

    /**
     * 
     * @param filePattern
     * @param minFileIdx
     * @param maxFileIdx
     * @param numInputs
     * @param numPossibleLabels
     * @param miniBatchSize
     * @param regression
     */
    public SingleFileTimeSeriesDataReader(final String filePattern, final int minFileIdx, final int maxFileIdx,
            final int numInputs, final int numPossibleLabels, final int miniBatchSize, final boolean regression) {
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.filePattern = filePattern;
        this.maxFileIdx = maxFileIdx;
        this.minFileIdx = minFileIdx;
        this.numInputs = numInputs;
    }

    /**
     *
     * @return
     * @throws IOException
     * @throws InterruptedException
     */
    public DataSetIterator apply() throws IOException, InterruptedException {
        final SequenceRecordReader reader = new CSVSequenceRecordReader(0, ",");
        reader.initialize(new NumberedFileInputSplit(filePattern, minFileIdx, maxFileIdx));
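        // Note (my reading of the single-reader constructor used below, not stated in the
        // post): the fourth argument is the label column index, so with numInputs = 1 the
        // value in column 1 of each CSV row is the regression target and column 0 is the
        // input feature; with regression = true, numPossibleLabels is not used for
        // one-hot encoding.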
        final DataSetIterator iter = new SequenceRecordReaderDataSetIterator(reader, miniBatchSize, numPossibleLabels,
                numInputs, regression);
        return iter;
    }
}

TestConfBuilder.java

/**
 *
 */
package org.mmarini.lstmtest;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

/**
 * @author mmarini
 *
 */
public class TestConfBuilder {

    private final int noInputUnits;
    private final int noOutputUnits;
    private final int noLstmUnits;

    /**
     *
     * @param noInputUnits
     * @param noOutputUnits
     * @param noLstmUnits
     */
    public TestConfBuilder(final int noInputUnits, final int noOutputUnits, final int noLstmUnits) {
        super();
        this.noInputUnits = noInputUnits;
        this.noOutputUnits = noOutputUnits;
        this.noLstmUnits = noLstmUnits;
    }

    /**
     *
     * @return
     */
    public MultiLayerConfiguration build() {
        final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        final LSTM lstmLayer = new LSTM.Builder().units(noLstmUnits).nIn(noInputUnits).activation(Activation.TANH)
                .build();
        final RnnOutputLayer outLayer = new RnnOutputLayer.Builder(LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
                .activation(Activation.IDENTITY).nOut(noOutputUnits).nIn(noLstmUnits).build();
        final MultiLayerConfiguration conf = builder.list(lstmLayer, outLayer).build();
        return conf;
    }
}
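A side note on the configuration above (my own observation, not part of the original post): builder.list(lstmLayer, outLayer) is the varargs form of DL4J's ListBuilder, so the more commonly documented pattern below should be roughly equivalent; here the LSTM layer size is given with nOut(noLstmUnits), which should match units(noLstmUnits) above.

final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.XAVIER)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .list()
        // layer 0: single LSTM cell
        .layer(0, new LSTM.Builder().nIn(noInputUnits).nOut(noLstmUnits)
                .activation(Activation.TANH).build())
        // layer 1: regression output over the time series
        .layer(1, new RnnOutputLayer.Builder(LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
                .activation(Activation.IDENTITY).nIn(noLstmUnits).nOut(noOutputUnits).build())
        .build();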

I expected no exception to be thrown, but the test below fails with the IllegalStateException quoted above.

TestTrainingTest.java

package org.mmarini.lstmtest;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

class TestTrainingTest {

    private static final int MINI_BATCH_SIZE = 10;
    private static final int NUM_LABELS = 1;
    private static final boolean REGRESSION = true;
    private static final String SAMPLES_FILE = "src/test/resources/datatest/sample_%d.csv";
    private static final int MIN_INPUTS_FILE_IDX = 0;
    private static final int MAX_INPUTS_FILE_IDX = 9;
    private static final int NUM_INPUTS_COLUMN = 1;
    private static final int NUM_HIDDEN_UNITS = 1;

    DataSetIterator createData() {
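        // Hand-built in-memory iterator (the "random dataset" experiment described in the
        // question). Note these arrays are created with 'c' ordering; the question reports
        // that building them with 'f' ordering runs without errors. This method is not
        // called by testBuild() below.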
        final double[][][] featuresAry = new double[][][] { { { 0.5, 0.2, 0.5 } }, { { 0.5, 1.0, 0.0 } } };
        final double[] featuresData = ArrayUtil.flattenDoubleArray(featuresAry);
        final int[] featuresShape = new int[] { 2, 1, 3 };
        final INDArray features = Nd4j.create(featuresData, featuresShape, 'c');

        final double[][][] labelsAry = new double[][][] { { { 1.0, -1.0, 1.0 }, { 1.0, -1.0, -1.0 } } };
        final double[] labelsData = ArrayUtil.flattenDoubleArray(labelsAry);
        final int[] labelsShape = new int[] { 2, 1, 3 };
        final INDArray labels = Nd4j.create(labelsData, labelsShape, 'c');

        final INDArrayDataSetIterator iter = new INDArrayDataSetIterator(
                Arrays.asList(new Pair<INDArray, INDArray>(features, labels)), 2);
        System.out.println(iter.inputColumns());
        return iter;
    }

    private String file(String template) {
        return new File(".", template).getAbsolutePath();
    }

    @Test
    void testBuild() throws IOException, InterruptedException {
        final SingleFileTimeSeriesDataReader reader = new SingleFileTimeSeriesDataReader(file(SAMPLES_FILE),
                MIN_INPUTS_FILE_IDX, MAX_INPUTS_FILE_IDX, NUM_INPUTS_COLUMN, NUM_LABELS, MINI_BATCH_SIZE, REGRESSION);

        final DataSetIterator data = reader.apply();

        assertThat(data.inputColumns(), equalTo(NUM_INPUTS_COLUMN));
        assertThat(data.totalOutcomes(), equalTo(NUM_LABELS));

        final TestConfBuilder builder = new TestConfBuilder(NUM_INPUTS_COLUMN, NUM_LABELS, NUM_HIDDEN_UNITS);
        final MultiLayerConfiguration conf = builder.build();
        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
        assertNotNull(net);
        net.init();
        net.fit(data);
    }

}

1 Answer:

Answer 0 (score: 0)