我目前正在阅读有关使用交叉熵损失训练pytorch documentation中的神经网络的信息。用于计算损失的标准称为
criterion = nn.CrossEntropyLoss()
根据torch.nn
documentation,CrossEntropyLoss
是一个类。根据我的理解,这意味着criterion
是nn.CrossEntropyLoss
类型的对象。
训练神经网络时,criterion
用于通过以下方式计算损耗
loss = criterion(input, target)
这让我有些困惑。 如果 criterion
是对象,那么如何将其用作功能?我期望的是类似的东西
loss = criterion.calculate_loss(input, target)
其中calculate_loss()
是nn.CrossEntropyLoss
类中定义的方法。此外,文档还使用以下代码行
running_loss += loss.item()
此item()
方法从何而来?我找不到在线提及item()
的消息源。
答案 0 :(得分:2)
如果条件是对象,那么如何将其用作功能?
在这种情况下,标准对象具有一个forward
方法。 criterion(input, target)
是criterion.forward(input, target)
的简写
这个item()方法从何而来?
此方法返回一维Tensor
。可以使用item()
作为数字访问单个值。