火炬中的自定义交叉熵损失

时间:2019-06-19 09:47:17

标签: machine-learning artificial-intelligence pytorch cross-entropy

我已经完成了pytorch交叉熵损失函数的自定义实现(因为我需要稍后介绍的更多灵活性)。我打算以此训练的模型将需要大量的时间来训练,并且可用的资源不能仅用于测试功能是否正确实现。我已经实现了矢量化实现,因为它将更快地运行。

以下是我的相同代码:

def custom_cross(my_pred,true,batch_size=BATCH_SIZE):
    loss= -torch.mean(torch.sum(true.view(batch_size, -1) * torch.log(my_pred.view(batch_size, -1)), dim=1))
    return loss

如果您可以建议对它进行更优化的实现,或者在当前操作中出现错误,我将不胜感激。该模型将使用Nvidia Tesla K-80进行训练。

1 个答案:

答案 0 :(得分:2)

如果您只需要cross entropy,则可以利用PyTorch定义的优势。

import torch.nn.functional as F
loss_func = F.cross_entropy
  

建议更优化的实现

PyTorch具有F.丢失函数,但是您可以使用纯Python轻松编写自己的函数。 PyTorch将自动为您的功能创建快速的GPU或矢量化的CPU代码。

因此,您可以检查PyTorch的原始实现,但我认为是这样的:

def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

here是交叉熵损失的原始实现,现在您可以更改:

nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

满足您的需求,便拥有了。