我的目标是建立一个自动编码网络,在该网络中,我可以训练身份功能,然后进行前向传递,从而重建输入。
为此,我尝试使用VariationalAutoencoder
,例如像这样:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(77147718)
.trainingWorkspaceMode(WorkspaceMode.NONE)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
.list()
.layer(0, new VariationalAutoencoder.Builder()
.activation(Activation.LEAKYRELU)
.nIn(100).nOut(15)
.encoderLayerSizes(120, 60, 30)
.decoderLayerSizes(30, 60, 120)
.pzxActivationFunction(Activation.IDENTITY)
.reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction()))
.build())
.pretrain(true).backprop(false)
.build();
但是,VariationalAutoencoder
似乎设计用于训练(并提供)从输入到编码版本的映射,即在上述示例配置中大小为100的向量到大小为15的向量。
但是,我对编码版本并不特别感兴趣,但是想训练一个100向量到其自身的映射。然后,我想通过它运行其他100个向量,并获得它们的重构版本。
但是即使查看VariationalAutoencoder
(或AutoEncoder
)的API,我也无法弄清楚该如何做。还是这些层不是为这种“端到端使用”而设计的,而我将不得不手动构建一个自动编码网络?
答案 0 :(得分:1)
您可以看到如何使用VAE层提取平均重建量from the variational example。
有两种方法可以从可变层获得重建。标准为string errorString = "";
string outString = "";
protected void receiveError(object sender, DataReceivedEventArgs args)
{
if (args.Data != null)
{
Console.WriteLine("StdErr:" + args.Data);
errorString += args.Data + " / ";
}
else
Console.WriteLine("StdErr: null");
}
protected void receiveStd(object sender, DataReceivedEventArgs args)
{
if (args.Data != null)
{
Console.WriteLine("StdOut:" + args.Data);
outString += args.Data + "/";
}
else
Console.WriteLine("StdOut: null");
}
string pyPath = @"C:\Users\...\Python\Python36\python.exe";
Process myProcess = new Process();
myProcess.StartInfo.FileName = pyPath;
myProcess.StartInfo.UseShellExecute = false;
myProcess.StartInfo.RedirectStandardInput = true;
myProcess.StartInfo.RedirectStandardOutput = true;
myProcess.StartInfo.RedirectStandardError = true;
//myProcess.StartInfo.Verb = "runas";
myProcess.ErrorDataReceived += new DataReceivedEventHandler(receiveError);
myProcess.OutputDataReceived += new DataReceivedEventHandler(receiveStd);
myProcess.Start();
myProcess.BeginErrorReadLine();
myProcess.BeginOutputReadLine();
StreamWriter myStreamWriter = myProcess.StandardInput;
myStreamWriter.WriteLine("print('hello world')");
myStreamWriter.WriteLine("exit()");
myProcess.WaitForExit();
Console.WriteLine("out="+outString);
Console.WriteLine("err="+errorString);
myProcess.Close();
,它将从图层中抽取样本并提供平均值。如果需要原始样本,可以使用generateAtMeanGivenZ
。有关其他所有方法,请参见the javadoc page。