使用Pytorch训练神经网络时了解类的使用

时间:2019-09-08 13:29:10

标签: python neural-network pytorch

我目前正在阅读有关使用交叉熵损失训练pytorch documentation中的神经网络的信息。用于计算损失的标准称为

criterion = nn.CrossEntropyLoss()

根据torch.nn documentationCrossEntropyLoss是一个类。根据我的理解,这意味着criterionnn.CrossEntropyLoss类型的对象。

训练神经网络时,criterion用于通过以下方式计算损耗

loss = criterion(input, target)

这让我有些困惑。 如果 criterion 是对象,那么如何将其用作功能?我期望的是类似的东西

loss = criterion.calculate_loss(input, target)

其中calculate_loss()nn.CrossEntropyLoss类中定义的方法。此外,文档还使用以下代码行

running_loss += loss.item()

item()方法从何而来?我找不到在线提及item()的消息源。

1 个答案:

答案 0 :(得分:2)

  

如果条件是对象,那么如何将其用作功能?

在这种情况下,标准对象具有一个forward方法。 criterion(input, target)criterion.forward(input, target)的简写

  

这个item()方法从何而来?

此方法返回一维Tensor。可以使用item()作为数字访问单个值。