我参照 EncoderDecoderLSTM 示例(https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/encdec/ )创建了一个网络:
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
// Global training settings shared by every layer of the graph.
builder.iterations(1)
        .learningRate(LEARNING_RATE)
        .rmsDecay(RMS_DECAY)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .miniBatch(true)
        .updater(Updater.RMSPROP)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer);

// NOTE(review): the backprop type is Standard, so the two tBPTT lengths below
// presumably have no effect -- confirm against the DL4J documentation.
GraphBuilder graphBuilder = builder.graphBuilder()
        .pretrain(false)
        .backprop(true)
        .backpropType(BackpropType.Standard)
        .tBPTTBackwardLength(TBPTT_SIZE)
        .tBPTTForwardLength(TBPTT_SIZE);

// Every dictionary-sized dimension below uses the same value.
final int dictSize = dict.size();

// Layer configurations, built up front so the graph wiring below stays readable.
EmbeddingLayer embeddingConf = new EmbeddingLayer.Builder()
        .nIn(dictSize)
        .nOut(EMBEDDING_WIDTH)
        .build();
GravesLSTM encoderConf = new GravesLSTM.Builder()
        .nIn(EMBEDDING_WIDTH)
        .nOut(HIDDEN_LAYER_WIDTH)
        .activation(Activation.TANH)
        .build();
// The decoder consumes the "merge" vertex: "decoderInput" (dictSize wide)
// combined with "dup" (HIDDEN_LAYER_WIDTH wide), hence the summed nIn.
GravesLSTM decoderConf = new GravesLSTM.Builder()
        .nIn(dictSize + HIDDEN_LAYER_WIDTH)
        .nOut(HIDDEN_LAYER_WIDTH)
        .activation(Activation.TANH)
        .build();
RnnOutputLayer outputConf = new RnnOutputLayer.Builder()
        .nIn(HIDDEN_LAYER_WIDTH)
        .nOut(dictSize)
        .activation(Activation.SOFTMAX)
        .lossFunction(LossFunctions.LossFunction.MCXENT)
        .build();

// Wiring: inputLine -> embeddingEncoder -> encoder -> thoughtVector -> dup,
// then (decoderInput + dup) -> merge -> decoder -> output.
graphBuilder.addInputs("inputLine", "decoderInput")
        .setInputTypes(InputType.recurrent(dictSize), InputType.recurrent(dictSize))
        .addLayer("embeddingEncoder", embeddingConf, "inputLine")
        .addLayer("encoder", encoderConf, "embeddingEncoder")
        .addVertex("thoughtVector", new LastTimeStepVertex("inputLine"), "encoder")
        .addVertex("dup", new DuplicateToTimeSeriesVertex("decoderInput"), "thoughtVector")
        .addVertex("merge", new MergeVertex(), "decoderInput", "dup")
        .addLayer("decoder", decoderConf, "merge")
        .addLayer("output", outputConf, "decoder")
        .setOutputs("output");

// Materialise the graph from the finished configuration and initialise it.
net = new ComputationGraph(graphBuilder.build());
net.init();
如何用单个样本对这个网络进行在线训练?
我认为应该是下面这样,但还需要加入“<go>”和“<eos>”向量:
/**
 * Fits {@code net} once on a single example wrapped in a {@link MultiDataSet}.
 *
 * <p>The features are the array built from {@code input} plus an all-zero
 * decoder array of shape {@code [1, dict.size(), 1]}; the label is the array
 * built from {@code prediction}. Both masks are created with
 * {@code Nd4j.ones(dict.size())}.
 *
 * <p>NOTE(review): no {@code <go>} / {@code <eos>} markers are encoded here,
 * so the decoder input carries no information -- confirm what the decoder
 * should receive.
 * <p>NOTE(review): the network inputs are declared as recurrent, which
 * presumably requires rank-3 feature/label arrays and masks sized per time
 * step rather than per dictionary entry -- verify against the DL4J docs.
 *
 * @param input      values for the "inputLine" network input
 * @param prediction values used as the label of the "output" layer
 */
private void train(float[] input, float[] prediction) {
    final int dictSize = dict.size();

    // Encoder features and labels are taken straight from the caller's arrays.
    INDArray encoderFeatures = Nd4j.create(input);
    INDArray labels = Nd4j.create(prediction);

    // Decoder features: a fresh float[] is all zeros.
    INDArray decoderFeatures = Nd4j.create(new float[dictSize], new int[]{1, dictSize, 1});

    // The decoder mask instance doubles as the label mask.
    INDArray encoderMask = Nd4j.ones(dictSize);
    INDArray decoderMask = Nd4j.ones(dictSize);

    INDArray[] features = {encoderFeatures, decoderFeatures};
    INDArray[] labelArrays = {labels};
    INDArray[] featureMasks = {encoderMask, decoderMask};
    INDArray[] labelMasks = {decoderMask};

    net.fit(new MultiDataSet(features, labelArrays, featureMasks, labelMasks));
}