因此,我刚刚学习了Pytorch,他们说必须先通过.train()方法将NN置于训练模式,然后才能推断.eval()模式。我正在阅读本教程,根本没有.train()。为什么会这样?
答案 0 :(得分:2)
.train()
将模块的self.training
属性设置为True
。从source for nn.Module
中可以看出,此属性最初设置为True
。因此,除非您在开始训练之前已经致电eval()
,否则您无需致电train()
。但是无论如何这样做可能是一个好习惯。
此外,.train()
和eval()
仅影响某些模块(例如dropout和batchnorm)。因此,如果您不使用这些模块,则不必真正调用它们,但是再次这样做,可能是一个好习惯。