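My network has two inputs: one is a time series (recurrent), the other is a regular feed-forward input. The following part of the graph builder code should make this clear: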
```java
final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder()
        .backpropType(BackpropType.TruncatedBPTT)
        .tBPTTBackwardLength(tbpttSize)
        .tBPTTForwardLength(tbpttSize)
        .addInputs("recurrentInput", "nonRecurrentInput")
        .setInputTypes(
                InputType.recurrent(numFeaturesRecurrent),
                InputType.feedForward(numFeaturesNonRecurrent))
        .addLayer("encoder",
                new LSTM.Builder()
                        .nIn(numFeaturesRecurrent)
                        .nOut(hiddenRecurrentSize)
                        .activation(Activation.TANH)
                        .build(),
                "recurrentInput")
        .addVertex("thoughtVector",
                new LastTimeStepVertex("recurrentInput"), "encoder")
        .addVertex("merge",
                new MergeVertex(), "thoughtVector", "nonRecurrentInput")
        ...
```
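To make the shapes in this graph concrete, here is a small plain-Java sketch (no DL4J calls) that writes out the activation shape at each vertex. It assumes DL4J's default layouts of [miniBatch, features, timeSteps] for recurrent data and [miniBatch, features] for feed-forward data; the concrete sizes in `main` are hypothetical, except that the non-recurrent input has 13 features. Only `recurrentInput` and `encoder` carry a time axis; after `LastTimeStepVertex` everything is rank 2, which is what lets `MergeVertex` concatenate `thoughtVector` with `nonRecurrentInput`.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

public class GraphShapeSketch {

    /**
     * Writes out the activation shape at each vertex of the graph above,
     * assuming recurrent = [miniBatch, features, timeSteps] and
     * feed-forward = [miniBatch, features]. Illustrative only, not DL4J code.
     */
    static Map<String, long[]> shapes(long miniBatch, long timeSteps, long numFeaturesRecurrent,
                                      long hiddenRecurrentSize, long numFeaturesNonRecurrent) {
        Map<String, long[]> s = new LinkedHashMap<>();
        s.put("recurrentInput", new long[]{miniBatch, numFeaturesRecurrent, timeSteps});
        s.put("nonRecurrentInput", new long[]{miniBatch, numFeaturesNonRecurrent});
        // LSTM maps input features to hidden units and keeps the time axis
        s.put("encoder", new long[]{miniBatch, hiddenRecurrentSize, timeSteps});
        // LastTimeStepVertex drops the time axis: rank 3 -> rank 2
        s.put("thoughtVector", new long[]{miniBatch, hiddenRecurrentSize});
        // MergeVertex concatenates the two rank-2 activations along the feature axis
        s.put("merge", new long[]{miniBatch, hiddenRecurrentSize + numFeaturesNonRecurrent});
        return s;
    }

    public static void main(String[] args) {
        // Hypothetical sizes; only numFeaturesNonRecurrent = 13 comes from the question
        shapes(32, 120, 5, 64, 13).forEach((name, shape) ->
                System.out.println(name + " -> " + Arrays.toString(shape)));
    }
}
```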
The TruncatedBPTT configuration is applied to all inputs, not just the recurrent one, and I get the following error:
```
java.lang.IllegalArgumentException: NDArrayIndex is out of range. Beginning index: 50 must be less than its size: 13
    at org.nd4j.linalg.indexing.NDArrayIndex.validate(NDArrayIndex.java:459)
    at org.nd4j.linalg.indexing.NDArrayIndex.resolve(NDArrayIndex.java:364)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4996)
    at org.deeplearning4j.nn.graph.ComputationGraph.getSubsetsForTbptt(ComputationGraph.java:3619)
    at org.deeplearning4j.nn.graph.ComputationGraph.doTruncatedBPTT(ComputationGraph.java:3568)
    at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1140)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1098)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1006)
    at org.mypackage.MultivariatePredictorNet.train(MultivariatePredictorNet.java:140)
    at org.mypackage.MultivariatePredictorNet.main(MultivariatePredictorNet.java:209)
```
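13 is exactly the number of features in the non-recurrent input. So how can I make the TruncatedBPTT configuration apply only to the recurrent input?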
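A simplified, library-free model of why the index 50 appears: with truncated BPTT each fit call is split into segments of `tbpttSize` time steps, and the stack trace shows `getSubsetsForTbptt` indexing every input for each segment. The sketch assumes that indexing takes the interval [start, end) along the last axis. For a rank-2 [miniBatch, 13] input the last axis is the feature axis, so in this model the segment [0, 50) still returns something, but the segment starting at 50 fails with the same message as above. `sliceLastAxis` and `tileOverTime` are hypothetical helpers, not DL4J code, and the sizes are made up apart from the 13 features and a `tbpttSize` of 50 inferred from the error. The tiling part illustrates one possible workaround, under the unverified assumption that the static features can instead be supplied as a rank-3 series (the same values repeated at every time step), declared with `InputType.recurrent(numFeaturesNonRecurrent)`, and reduced back to rank 2 (for example with a second `LastTimeStepVertex`) before the merge.

```java
import java.util.Arrays;

public class TbpttSliceSketch {

    /**
     * Hypothetical model of per-segment TBPTT slicing: take [start, end)
     * along the LAST axis of an input with the given shape. Throws when start
     * is past the end of that axis, similar in spirit to the NDArrayIndex check.
     */
    static int[] sliceLastAxis(int[] shape, int start, int end) {
        int last = shape[shape.length - 1];
        if (start >= last) {
            throw new IllegalArgumentException(
                    "Beginning index: " + start + " must be less than its size: " + last);
        }
        int[] out = shape.clone();
        out[out.length - 1] = Math.min(end, last) - start;
        return out;
    }

    /** Repeats a [miniBatch][features] matrix at every step: -> [miniBatch][features][timeSteps]. */
    static double[][][] tileOverTime(double[][] x, int timeSteps) {
        double[][][] out = new double[x.length][][];
        for (int i = 0; i < x.length; i++) {
            out[i] = new double[x[i].length][timeSteps];
            for (int f = 0; f < x[i].length; f++) {
                Arrays.fill(out[i][f], x[i][f]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical sizes: 13 static features (from the question), tbpttSize = 50
        int mb = 4, recurrentFeatures = 5, staticFeatures = 13, timeSteps = 120, tbpttSize = 50;
        int[] recurrent = {mb, recurrentFeatures, timeSteps};
        int[] staticRank2 = {mb, staticFeatures};
        double[][][] tiled = tileOverTime(new double[mb][staticFeatures], timeSteps);
        int[] staticRank3 = {tiled.length, tiled[0].length, tiled[0][0].length};

        for (int start = 0; start < timeSteps; start += tbpttSize) {
            int end = Math.min(start + tbpttSize, timeSteps);
            System.out.println("segment [" + start + ", " + end + ")");
            System.out.println("  recurrent      -> " + Arrays.toString(sliceLastAxis(recurrent, start, end)));
            System.out.println("  static, tiled  -> " + Arrays.toString(sliceLastAxis(staticRank3, start, end)));
            try {
                System.out.println("  static, rank 2 -> " + Arrays.toString(sliceLastAxis(staticRank2, start, end)));
            } catch (IllegalArgumentException e) {
                System.out.println("  static, rank 2 -> " + e.getMessage());
            }
        }
    }
}
```

The tiled input is sliced consistently in every segment because its last axis is time, just like the recurrent input; the cost is storing the static features once per time step.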