How to train a DeepLearning4j ComputationGraph with a single example

Time: 2017-05-18 16:53:56

Tags: java deeplearning4j

I created the following network, based on the EncoderDecoderLSTM example at https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/encdec/

NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
    builder.iterations(1).learningRate(LEARNING_RATE).rmsDecay(RMS_DECAY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).miniBatch(true).updater(Updater.RMSPROP)
            .weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer);

    GraphBuilder graphBuilder = builder.graphBuilder().pretrain(false).backprop(true).backpropType(BackpropType.Standard)
            .tBPTTBackwardLength(TBPTT_SIZE).tBPTTForwardLength(TBPTT_SIZE);
    graphBuilder.addInputs("inputLine", "decoderInput")
            .setInputTypes(InputType.recurrent(dict.size()), InputType.recurrent(dict.size()))
            .addLayer("embeddingEncoder", new EmbeddingLayer.Builder().nIn(dict.size()).nOut(EMBEDDING_WIDTH).build(), "inputLine")
            .addLayer("encoder",
                    new GravesLSTM.Builder().nIn(EMBEDDING_WIDTH).nOut(HIDDEN_LAYER_WIDTH).activation(Activation.TANH).build(),
                    "embeddingEncoder")
            .addVertex("thoughtVector", new LastTimeStepVertex("inputLine"), "encoder")
            .addVertex("dup", new DuplicateToTimeSeriesVertex("decoderInput"), "thoughtVector")
            .addVertex("merge", new MergeVertex(), "decoderInput", "dup")
            .addLayer("decoder",
                    new GravesLSTM.Builder().nIn(dict.size() + HIDDEN_LAYER_WIDTH).nOut(HIDDEN_LAYER_WIDTH).activation(Activation.TANH)
                            .build(),
                    "merge")
            .addLayer("output", new RnnOutputLayer.Builder().nIn(HIDDEN_LAYER_WIDTH).nOut(dict.size()).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "decoder")
            .setOutputs("output");

    net = new ComputationGraph(graphBuilder.build());
    net.init();
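To make the data layout easier to reason about, here is a plain-Java illustration (no DL4J dependency) of what the three vertices in the graph above do to one example, with arrays indexed [feature][time]. It assumes the documented behaviour of the vertices: LastTimeStepVertex("inputLine") takes the encoder output at the last time step whose "inputLine" mask is 1, DuplicateToTimeSeriesVertex("decoderInput") repeats that vector for every decoder time step, and MergeVertex stacks its inputs along the feature axis. The class and method names are hypothetical and only mirror the vertices, they are not DL4J code.

```java
import java.util.Arrays;

/**
 * Hypothetical, DL4J-free illustration of the "thoughtVector", "dup" and "merge"
 * vertices for a single example. Arrays are indexed [feature][time].
 */
final class ThoughtVectorDemo {

    /** Like LastTimeStepVertex: encoder output at the last time step whose mask is non-zero. */
    static float[] lastTimeStep(float[][] encoderOut, float[] inputMask) {
        int last = -1;
        for (int t = 0; t < inputMask.length; t++) {
            if (inputMask[t] != 0f) last = t;
        }
        if (last < 0) throw new IllegalArgumentException("mask has no active time step");
        float[] v = new float[encoderOut.length];
        for (int f = 0; f < encoderOut.length; f++) v[f] = encoderOut[f][last];
        return v;
    }

    /** Like DuplicateToTimeSeriesVertex: repeat v for every decoder time step. */
    static float[][] duplicate(float[] v, int steps) {
        float[][] out = new float[v.length][steps];
        for (int f = 0; f < v.length; f++) Arrays.fill(out[f], v[f]);
        return out;
    }

    /**
     * Like MergeVertex on time series: concatenate along the feature axis,
     * first input's features first. With inputs ("decoderInput", "dup") this
     * gives dict.size() + HIDDEN_LAYER_WIDTH features, the decoder's nIn.
     */
    static float[][] merge(float[][] first, float[][] second) {
        int steps = first[0].length;
        if (second[0].length != steps) throw new IllegalArgumentException("time lengths differ");
        float[][] out = new float[first.length + second.length][];
        for (int f = 0; f < first.length; f++) out[f] = first[f].clone();
        for (int f = 0; f < second.length; f++) out[first.length + f] = second[f].clone();
        return out;
    }
}
```

One consequence worth noting: the input mask decides which encoder step becomes the thought vector, and the decoder input decides how many steps the decoder (and therefore the labels) must have.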

How can I train this network online, one example at a time?

I think it should look something like this, but I don't know how to add the "<go>" and "<eos>" vectors:

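private void train(float[] input, float[] prediction){
    // Assumption: input holds one encoder token index per time step (the first layer is an
    // EmbeddingLayer), prediction is a one-hot [dict.size() x outLen] block flattened in 'c' order
    int inLen = input.length;
    int outLen = prediction.length / dict.size();
    // Recurrent arrays must be rank 3: [miniBatch, features, timeSteps]
    INDArray in = Nd4j.create(input, new int[]{1, 1, inLen});
    INDArray pred = Nd4j.create(prediction, new int[]{1, dict.size(), outLen});
    // Decoder input needs as many time steps as the labels; <go> and the shifted reply are still missing
    INDArray decode = Nd4j.zeros(new int[]{1, dict.size(), outLen});
    // Mask arrays are [miniBatch, timeSteps], not [dict.size()]
    INDArray inputMask = Nd4j.ones(1, inLen);
    INDArray predictionMask = Nd4j.ones(1, outLen);
    MultiDataSet data = new MultiDataSet(new INDArray[]{in, decode}, new INDArray[]{pred},
        new INDArray[]{inputMask, predictionMask}, new INDArray[]{predictionMask});
    net.fit(data);
}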
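A common way to place "<go>" and "<eos>" in a single seq2seq training example (teacher forcing) is: decoder input = "<go>" followed by the reply, labels = the reply followed by "<eos>", so both have reply.length + 1 time steps and share one mask. The sketch below is plain Java with no DL4J dependency; it builds those arrays for one example as flat row-major ('c' order) float arrays matching the shapes the graph expects. The class name and the GO/EOS indices are hypothetical; take the real indices from your own dict.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: builds the arrays for ONE encoder-decoder example,
 * shaped like DL4J recurrent arrays [miniBatch=1, features, timeSteps] and
 * flattened in 'c' order, so element [0, f, t] sits at f * len + t.
 */
final class SingleExample {
    static final int GO = 2;   // hypothetical dict index of "<go>"
    static final int EOS = 1;  // hypothetical dict index of "<eos>"

    final int vocab, inLen, outLen;
    final float[] encoderInput;  // [1, 1, inLen]: one token index per step (EmbeddingLayer input)
    final float[] decoderInput;  // [1, vocab, outLen]: one-hot, "<go>" then the reply
    final float[] labels;        // [1, vocab, outLen]: one-hot, the reply then "<eos>"
    final float[] inputMask;     // [1, inLen]
    final float[] labelMask;     // [1, outLen], also usable as the decoderInput mask

    SingleExample(int[] question, int[] reply, int vocab) {
        if (vocab <= Math.max(GO, EOS)) throw new IllegalArgumentException("vocab too small");
        this.vocab = vocab;
        this.inLen = question.length;
        this.outLen = reply.length + 1; // one extra step for <go> / <eos>

        encoderInput = new float[inLen];
        for (int t = 0; t < inLen; t++) encoderInput[t] = question[t];

        decoderInput = new float[vocab * outLen];
        labels = new float[vocab * outLen];
        for (int t = 0; t < outLen; t++) {
            // Teacher forcing: at step t the decoder sees the previous target token
            int in = (t == 0) ? GO : reply[t - 1];
            int out = (t == reply.length) ? EOS : reply[t];
            decoderInput[in * outLen + t] = 1f;
            labels[out * outLen + t] = 1f;
        }

        inputMask = new float[inLen];
        Arrays.fill(inputMask, 1f);
        labelMask = new float[outLen];
        Arrays.fill(labelMask, 1f);
    }

    /** Token index that is hot at step t of a flattened [1, vocab, len] one-hot array, or -1. */
    static int hotIndexAt(float[] oneHot, int vocab, int len, int t) {
        for (int f = 0; f < vocab; f++) {
            if (oneHot[f * len + t] == 1f) return f;
        }
        return -1;
    }
}
```

Each field could then be wrapped with `Nd4j.create(float[], int[])` using the shape in its comment and passed to `new MultiDataSet(features, labels, featureMasks, labelMasks)` in the same order as the `train` method, followed by `net.fit(...)`.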

0 answers
