为什么在此Pytorch官方教程中没有.train()方法?

时间:2019-07-06 21:47:52

标签: pytorch

因此,我刚刚学习了Pytorch,他们说必须先通过.train()方法将NN置于训练模式,然后才能推断.eval()模式。我正在阅读本教程,根本没有.train()。为什么会这样?

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

1 个答案:

答案 0 :(得分:2)

.train()将模块的self.training属性设置为True。从source for nn.Module中可以看出,此属性最初设置为True。因此,除非您在开始训练之前已经致电eval(),否则您无需致电train()。但是无论如何这样做可能是一个好习惯。

此外,.train()eval()仅影响某些模块(例如dropout和batchnorm)。因此,如果您不使用这些模块,则不必真正调用它们,但是再次这样做,可能是一个好习惯。