它在forward()
中叫nn.Module
吗?我以为,当我们调用模型时,正在使用forward
方法。
为什么我们需要指定train()?
答案 0 :(得分:18)
model.train()
告诉模型您正在训练模型。因此,在培训和测试过程中表现不同的有效层(例如辍学,batchnorm等)可以知道发生了什么,因此可以相应地表现。
更多详细信息:
设置训练模式
(请参见source code)。您可以调用model.eval()或model.train(mode = False)来告诉您正在测试。
期望使用train
函数来训练模型有些直观,但是并没有做到这一点。它只是设置模式。
答案 1 :(得分:4)
model.train() |
model.eval() |
---|---|
将您的模型设置为训练模式,即 • BatchNorm 层使用每批统计数据• Dropout 层已激活etc |
将您的模型设置为评估评估(推理)模式,即 • BatchNorm 层使用运行统计数据• Dropout 层已停用等。相当于 model.train(False) 。 |
注意:这些函数调用都没有向前/向后传递。它们告诉模型如何在运行时采取行动。
这很重要,因为 some modules (layers)(例如 Dropout
、BatchNorm
)被设计为在训练和推理期间表现不同,因此如果在错误的模式下运行,模型将产生意想不到的结果.
答案 2 :(得分:3)
这是module.train()
的代码:
def train(self, mode=True):
r"""Sets the module in training mode."""
self.training = mode
for module in self.children():
module.train(mode)
return self
这是module.eval
。
def eval(self):
r"""Sets the module in evaluation mode."""
return self.train(False)
模式train
和eval
是我们可以在其中设置模块的仅有的两种模式,它们是完全相反的。
那只是一个self.training
标志,目前仅 dropout和bachnorm在意那个标志。
默认情况下,此标志设置为True
。
答案 3 :(得分:1)
有两种方法可以让模型知道您的意图,即您要训练模型还是要使用模型进行评估。 在使用model.train()的情况下,模型知道必须学习层,并且当我们使用model.eval()时,它指示模型没有新知识要学习,并且该模型用于测试。 model.eval()也是必需的,因为在pytorch中,如果我们使用batchnorm,而在测试过程中,如果我们只想传递单个图像,则如果未指定model.eval(),则pytorch将引发错误。
答案 4 :(得分:0)
当前的 official documentation 声明如下:
<块引用>这仅对某些模块有任何 [sic] 影响。如果它们受到影响,请参阅特定模块的文档以了解其在培训/评估模式下的行为的详细信息,例如Dropout、BatchNorm 等
答案 5 :(得分:0)
考虑以下模型
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GraphNet(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphNet, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.dropout(x, training=self.training) #Look here
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
此处,dropout
的功能因操作模式不同而不同。如您所见,它仅在 self.training==True
时有效。因此,当您键入 model.train()
时,模型的前向函数将执行 dropout,否则不会执行(例如在 model.eval()
或 model.train(mode=False)
时)。