在教程中编写的Pytorch model.train()和separte train()函数

时间:2019-06-25 16:32:09

标签: python machine-learning pytorch

我是PyTorch的新手,我想知道您是否可以向我解释PyTorch中的默认model.train()与此处的train()函数之间的一些关键区别。

PyTorch官方文本分类教程中的另一个train()函数,对于在训练结束时是否存储模型权重感到困惑。

https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

true

这是功能。然后以以下形式多次调用此函数:

let arr = [ 
  {name: "abc", show: true , display:"ABC"},
  {name: "xyz", show: false , display:"XYZ"},
  {name: "pqr", show: true , display:"PQR"},
  {name: "lmn", show: false , display:"LMN"}
];
          
let [arr2, arr1] = arr.reduce(
        (acc, {name, show}) => (acc[+show].push(name), acc), 
    [[], []]);

console.log(arr1); // ["abc", "pqr"]
console.log(arr2); // ["xyz", "lmn"]

在我看来,当这样编写时,似乎并不保存或更新模型权重,而是在每次迭代时都将其覆盖。它是否正确?还是该模型似乎正在正确训练?

另外,如果我要使用默认函数model.train(),主要的优点是什么,model.train()会执行与上面的train()函数大致相同的功能吗?

1 个答案:

答案 0 :(得分:1)

根据源代码heremodel.train()将模块设置为训练模式。因此,它基本上告诉您的模型您正在训练模型。这仅对某些模块(例如dropoutbatchnorm等)有效,它们在训练/评估模式下的行为会有所不同。对于model.train(),模型知道必须学习层次。

您可以调用model.eval()model.train(mode=False)来告诉模型它没有什么新知识,并且该模型用于测试目的。

model.train()仅设置模式。它实际上并没有训练模型。

train()实际上是在训练模型,即计算梯度并进行反向传播以学习权重。

从pytorch官方论坛here了解有关model.train()的更多信息。