相当于Tensorflow损失函数的PyTorch

时间:2019-12-05 03:36:19

标签: tensorflow pytorch

我试图使用PyTorch框架重新实现TensorFlow代码。下面,我为目标(Batch, 9, 9, 4)和目标为(Batch, 9, 9, 4)

的网络输出提供了TF示例代码和PyT解释。

TensorFlow实现:

loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
loss = tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0)

PyTorch实现:

output = torch.tensor(output, requires_grad=True).view(-1, 4)
target = torch.tensor(target).view(-1, 4).argmax(1)

loss = torch.nn.CrossEntropyLoss(reduction='none')
my_loss = loss(output, target).view(-1,9,9)

对于PyTorch的实现,我不确定如何实现tf.matrix_band_part。我当时正在考虑定义一个遮罩,但是我不确定这是否会损害反向传播。我知道torch.triu,但是此功能不适用于2维以上的张量。

1 个答案:

答案 0 :(得分:1)

(至少)自1.2.0 torch.triu版以来,该版本可以很好地与per docs配合使用。

您可以通过einsumtorch.einsum('...ii->...i', A)来获取对角线元素。

使用口罩不会伤及背部防护。您可以将其视为投影(显然,这对于反向传播很有效)。