错误:标签和preOutput必须具有相同的形状

时间:2020-04-14 07:05:40

标签: lstm deeplearning4j

我正在使用deeplearning4j,并且不断收到此错误。总的来说,我对dl4j和AI非常陌生。我试图根据给定的10个先前值来预测下一个值。我正在使用LSTM。我很确定我可能必须使用遮罩,但我一无所知,dl4j社区还很小。这非常繁琐。这是我的代码: //我之前对数据进行了归一化,范围在3500-6500之间

 DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize,
            1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.EQUAL_LENGTH);
 DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize,
            1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.EQUAL_LENGTH);

 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .list()
            .layer(new LSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.RELU).nIn(10).nOut(1).build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

这是错误:

Caused by: java.lang.IllegalArgumentException: Labels and preOutput must have equal shapes: got shapes [5, 1, 1] vs [50, 1]

我的数据由单列时间序列CSV文件组成,每个文件包含10个序列。标签是代表下一个值的单个值。

这里有我的火车功能和trainLabels声明

SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(numLinesToSkip, delimiter);
        trainFeatures.initialize(new NumberedFileInputSplit(testAndTrainFeatures + "%d.csv",0,800));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader(numLinesToSkip, delimiter);
        trainLabels.initialize(new NumberedFileInputSplit(testAndTrainLabels + "%d.csv", 0, 800));

1 个答案:

答案 0 :(得分:0)

您面临的问题是,您期望配置意味着从LSTM获得最后一个时间步作为您的密集层的输入。而是将LSTM的所有输出合并并馈入密集层。

要解决此问题,您应该将LSTM层包装在LastTimestep(...)中,以便仅实际发生您期望的事情。

但是,您将遇到第二个问题,那就是将标签视为序列,因为这是加载它们的方式。为了使标签不连续,您必须像这样在数据集迭代器上应用预处理器:

trainData.setPreProcessor(new LabelLastTimeStepPreProcessor());

如果在应用这些更改后仍然遇到问题,请随时在社区论坛上提问,您可能会更快地得到答案。