我试图使用PyTorch框架重新实现TensorFlow代码。下面,我为目标(Batch, 9, 9, 4)
和目标为(Batch, 9, 9, 4)
TensorFlow实现:
loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
loss = tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0)
PyTorch实现:
output = torch.tensor(output, requires_grad=True).view(-1, 4)
target = torch.tensor(target).view(-1, 4).argmax(1)
loss = torch.nn.CrossEntropyLoss(reduction='none')
my_loss = loss(output, target).view(-1,9,9)
对于PyTorch的实现,我不确定如何实现tf.matrix_band_part
。我当时正在考虑定义一个遮罩,但是我不确定这是否会损害反向传播。我知道torch.triu
,但是此功能不适用于2维以上的张量。