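Reading back reconstructed vectors from an autoencoder in DL4J

Time: 2018-07-13 07:40:26

Tags: autoencoder deeplearning4j dl4j

My goal is to build an autoencoder network in which I can train the identity function and then do forward passes that yield a reconstruction of the input.

For this, I am trying to use VariationalAutoencoder, e.g. something like this: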

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(77147718)
                .trainingWorkspaceMode(WorkspaceMode.NONE)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                .list()
                .layer(0, new VariationalAutoencoder.Builder()
                        .activation(Activation.LEAKYRELU)
                        .nIn(100).nOut(15)
                        .encoderLayerSizes(120, 60, 30)
                        .decoderLayerSizes(30, 60, 120)
                        .pzxActivationFunction(Activation.IDENTITY)
                        .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction()))
                        .build())
                .pretrain(true).backprop(false)
                .build();

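However, VariationalAutoencoder seems to be designed for training (and providing) a mapping from the input to an encoded version, i.e. from a vector of size 100 to a vector of size 15 in the example configuration above.

I am not particularly interested in the encoded version, though. I want to train a mapping from a 100-vector to itself, then run other 100-vectors through it and get back their reconstructed versions.

But even after looking at the API of VariationalAutoencoder (or AutoEncoder), I can't figure out how to do this. Or are these layers not designed for this kind of "end-to-end use", so that I would have to construct an autoencoder network by hand?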

1 answer:

Answer 0: (score: 1)

You can see how to use the VAE layer to extract the mean reconstruction from the variational example.

There are two ways to get reconstructions from the variational layer. The standard one is generateAtMeanGivenZ, which decodes the latent values and gives you the mean of the reconstruction distribution. If you want raw samples, you can use generateRandomGivenZ. See the javadoc page for all other methods.
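
A minimal sketch of the end-to-end pass the question asks for, assuming layer 0 of an already trained network is the VariationalAutoencoder from the question's configuration. The class name VaeReconstruction and the method reconstruct are made up for illustration. The three-argument activate call with LayerWorkspaceMgr is an assumption about the DL4J version (it matches the 1.0.0-beta line; older releases appear to use activate(input, false) instead).

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;

// Hypothetical helper class, not part of DL4J.
public class VaeReconstruction {

    /**
     * Encodes each row of input to its latent mean and decodes it again.
     * Assumes layer 0 of net is a VariationalAutoencoder layer.
     */
    public static INDArray reconstruct(MultiLayerNetwork net, INDArray input) {
        // The runtime layer class lives in nn.layers.variational; it shares its
        // simple name with the configuration class, so it is fully qualified here.
        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae =
                (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);

        // Encoder pass: [n, nIn] -> [n, nOut], the latent mean for each input row.
        // Assumption: this signature applies to the DL4J version in use.
        INDArray latent = vae.activate(input, false, LayerWorkspaceMgr.noWorkspaces());

        // Decoder pass: [n, nOut] -> [n, nIn], mean of the reconstruction distribution.
        return vae.generateAtMeanGivenZ(latent);
    }
}
```

Calling VaeReconstruction.reconstruct(net, vectors) on a batch of 100-vectors should then return a batch of the same shape. Swapping generateAtMeanGivenZ for generateRandomGivenZ would return random samples from the reconstruction distribution instead of its mean.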