pytorch中的model.train()有什么作用?

时间:2018-07-20 00:10:00

标签: python pytorch

它在forward()中叫nn.Module吗?我以为,当我们调用模型时,正在使用forward方法。 为什么我们需要指定train()?

6 个答案:

答案 0 :(得分:18)

model.train()告诉模型您正在训练模型。因此,在培训和测试过程中表现不同的有效层(例如辍学,batchnorm等)可以知道发生了什么,因此可以相应地表现。

更多详细信息: 设置训练模式 (请参见source code)。您可以调用model.eval()或model.train(mode = False)来告诉您正在测试。 期望使用train函数来训练模型有些直观,但是并没有做到这一点。它只是设置模式。

答案 1 :(得分:4)

<头>
model.train() model.eval()
将您的模型设置为训练模式,即

BatchNorm 层使用每批统计数据
Dropout 层已激活etc


将您的模型设置为评估评估(推理)模式,即

BatchNorm 层使用运行统计数据
Dropout 层已停用等。

相当于model.train(False)

注意:这些函数调用都没有向前/向后传递。它们告诉模型如何在运行时采取行动。

这很重要,因为 some modules (layers)(例如 DropoutBatchNorm)被设计为在训练和推理期间表现不同,因此如果在错误的模式下运行,模型将产生意想不到的结果.

答案 2 :(得分:3)

这是module.train()的代码:

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

这是module.eval

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

模式traineval是我们可以在其中设置模块的仅有的两种模式,它们是完全相反的。

那只是一个self.training标志,目前仅 dropoutbachnorm在意那个标志。

默认情况下,此标志设置为True

答案 3 :(得分:1)

有两种方法可以让模型知道您的意图,即您要训练模型还是要使用模型进行评估。 在使用model.train()的情况下,模型知道必须学习层,并且当我们使用model.eval()时,它指示模型没有新知识要学习,并且该模型用于测试。 model.eval()也是必需的,因为在pytorch中,如果我们使用batchnorm,而在测试过程中,如果我们只想传递单个图像,则如果未指定model.eval(),则pytorch将引发错误。

答案 4 :(得分:0)

当前的 official documentation 声明如下:

<块引用>

这仅对某些模块有任何 [sic] 影响。如果它们受到影响,请参阅特定模块的文档以了解其在培训/评估模式下的行为的详细信息,例如Dropout、BatchNorm 等

答案 5 :(得分:0)

考虑以下模型

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training) #Look here
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

此处,dropout 的功能因操作模式不同而不同。如您所见,它仅在 self.training==True 时有效。因此,当您键入 model.train() 时,模型的前向函数将执行 dropout,否则不会执行(例如在 model.eval()model.train(mode=False) 时)。