Applying truncated BPTT only to the recurrent input in deeplearning4j

Posted: 2018-12-22 18:58:59

Tags: machine-learning deep-learning deeplearning4j

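My network has two inputs: one is a time series (recurrent), the other is a regular feed-forward input. The following part of the graph builder code should make the setup clear: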

    final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder()
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTBackwardLength(tbpttSize)
            .tBPTTForwardLength(tbpttSize)
            .addInputs("recurrentInput", "nonRecurrentInput")
            .setInputTypes(
                    InputType.recurrent(numFeaturesRecurrent),
                    InputType.feedForward(numFeaturesNonRecurrent))
            .addLayer("encoder",
                    new LSTM.Builder()
                            .nIn(numFeaturesRecurrent)
                            .nOut(hiddenRecurrentSize)
                            .activation(Activation.TANH)
                            .build(),
                    "recurrentInput")
            .addVertex("thoughtVector",
                    new LastTimeStepVertex("recurrentInput"), "encoder")
            .addVertex("merge",
                    new MergeVertex(), "thoughtVector", "nonRecurrentInput")
            ...

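The TruncatedBPTT configuration parameters are applied to all of the inputs, not just the recurrent one, and I get the following error: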

java.lang.IllegalArgumentException: NDArrayIndex is out of range. Beginning index: 50 must be less than its size: 13
    at org.nd4j.linalg.indexing.NDArrayIndex.validate(NDArrayIndex.java:459)
    at org.nd4j.linalg.indexing.NDArrayIndex.resolve(NDArrayIndex.java:364)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4996)
    at org.deeplearning4j.nn.graph.ComputationGraph.getSubsetsForTbptt(ComputationGraph.java:3619)
    at org.deeplearning4j.nn.graph.ComputationGraph.doTruncatedBPTT(ComputationGraph.java:3568)
    at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1140)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1098)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1006)
    at org.mypackage.MultivariatePredictorNet.train(MultivariatePredictorNet.java:140)
    at org.mypackage.MultivariatePredictorNet.main(MultivariatePredictorNet.java:209)

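13 happens to be the number of features in the non-recurrent input. So how can I make the TruncatedBPTT configuration apply only to the recurrent input?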
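To see why the index 50 collides with a size of 13, here is a minimal plain-Java sketch of the usual truncated-BPTT windowing: the time axis is cut into consecutive windows of `tbpttLength` steps. It does not use DL4J and is not copied from DL4J's source; the series length of 120 and the window length of 50 are hypothetical values (50 is merely consistent with the "Beginning index: 50" in the error). The helper names `windows` and `startInRange` are illustrative only. In DL4J, recurrent data is 3D (`[miniBatch, features, timeSteps]`) while feed-forward data is 2D (`[miniBatch, features]`), so a time-window index checked against the 2D input ends up being compared with its feature count of 13.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Sketch of truncated-BPTT windowing (assumption: windows are
 * [i * len, min((i + 1) * len, seriesLength)), the usual TBPTT scheme).
 */
public class TbpttWindows {

    /** Returns {start, end} pairs (end exclusive) covering [0, seriesLength). */
    static List<int[]> windows(int seriesLength, int tbpttLength) {
        if (seriesLength <= 0 || tbpttLength <= 0) {
            throw new IllegalArgumentException("lengths must be positive");
        }
        List<int[]> result = new ArrayList<>();
        for (int start = 0; start < seriesLength; start += tbpttLength) {
            int end = Math.min(start + tbpttLength, seriesLength);
            result.add(new int[] {start, end});
        }
        return result;
    }

    /** Mirrors the check in the error message: start must be less than the axis size. */
    static boolean startInRange(int start, int axisSize) {
        return start < axisSize;
    }

    public static void main(String[] args) {
        int seriesLength = 120;    // hypothetical length of the recurrent input
        int tbpttLength = 50;      // hypothetical value of tbpttSize
        int nonRecurrentAxis = 13; // feature count of the 2D non-recurrent input

        for (int[] w : windows(seriesLength, tbpttLength)) {
            System.out.printf("window [%d, %d) start < %d: %b%n",
                    w[0], w[1], nonRecurrentAxis, startInRange(w[0], nonRecurrentAxis));
        }
    }
}
```

With these values the first window starts at 0 and passes the check, while the second window starts at 50, which is the first start index that is not less than 13, matching the index reported in the stack trace.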
