TensorFlow:来自Gumbel Softmax的样本整数

时间:2019-03-07 06:59:14

标签: python tensorflow tensorflow-probability

我正在实现一个从分类分布中采样整数的程序,其中每个整数都与概率相关联。我需要确保该程序是可微分的,以便可以应用反向传播。我发现java -jar application.jar --spring.profiles.active=prod --spring.config.location=c:\config 与我要达到的目标非常接近。

但是,此类的tf.contrib.distributions.RelaxedOneHotCategorical方法返回的是一个单向量,而不是整数。如何编写既可微且返回整数/标量而不是向量的程序?

3 个答案:

答案 0 :(得分:0)

RelaxedOneHotCategorical实际上可区分的原因与以下事实有关:它返回浮点数的softmax向量,而不是argmax int索引。如果只需要最大元素的索引,则不妨使用Categorical。

答案 1 :(得分:0)

您可以将松弛的一个热点向量与向量defaultViewType: "thumbnails", defaultViewType_Files: "compact" 进行点积运算。结果将为您提供所需的标量。

例如,如果您的一个热门向量是onInit: function(finder) { finder.on("folder:getFiles:before", function(event) { var folder = finder.request("folder:getActive"); var resource = folder.getPath({full: true}).replace(/:.*$/, ""); switch (resource) { case "Files": finder.request("files.changeView", "compact"); break; case "Images": finder.request("files.changeView", "thumbnails"); break; }); } ,那么files.changeView会给您4,这正是您想要的。

答案 2 :(得分:0)

您无法以可微分的方式获得您想要的东西,因为 argmax 不可微分,这就是首先创建 Gumbel-Softmax 分布的原因。例如,这允许您将语言模型的输出用作生成对抗网络中判别器的输入,因为随着温度的变化,激活会接近一个单热向量。

如果您只是需要在推理或测试时检索最大元素,则可以使用 tf.math.argmax。但是没有办法以可微分的方式做到这一点。