如何批量计算指针网络的交叉熵?

时间:2018-03-26 04:47:58

标签: python-3.x tensorflow deep-learning softmax cross-entropy

在指针网络中,输出logits超过输入的长度。使用此类批次意味着将输入填充到批输入的最大长度。现在,这一切都很好,直到我们必须计算损失。目前我在做的是:

logits = stabilize(logits(inputs))     #[batch, max_length]. subtract max(logits) to stabilize
masks = masks(inputs)     #[batch, max_length]. 1 for actual inputs, 0 for padded locations
exp_logits = exp(logits)
exp_logits_masked = exp_logits*masks
probs = exp_logits_masked/sum(exp_logits_masked)

现在我使用这些概率来计算交叉熵

cross_entropy = sum_over_batches(probs[correct_class])

我能做得比这更好吗?关于如何通过处理指针网络的人来完成任何想法?

如果我没有可变大小的输入,这一切都可以使用logits和标签上的可调用tf.nn.softmax_cross_entropy_with_logits来实现(这是高度优化的),但是变量长度会产生错误的结果,因为softmax计算的分母大于1输入中的每个填充。

1 个答案:

答案 0 :(得分:1)

你希望看到你的方法,并且据我所知,这也是如何在RNN细胞中实现的。请注意,1x = dx的导数和0x = 0的导数。这会产生您想要的结果,因为您要对网络末端的梯度求和/求平均值。

您唯一可以考虑的是根据屏蔽值的数量重新调整损失。您可能会注意到,当有0个屏蔽值时,您的渐变的幅度将与您使用许多蒙版值的幅度略有不同。我不清楚这会产生重大影响,但也许会产生很小的影响。

否则,我自己也使用了同样的技术取得了巨大的成功,所以我在这里说你正走在正确的轨道上。