Pytorch:训练期间的中级测试

时间:2018-01-12 18:52:20

标签: pytorch

如何在培训期间测试验证数据上的pytorch模型? 我知道有一个函数myNet.eval()可以切换任何辍学图层,但它是否也可以防止渐变累积?
另外,如何撤消myNet.eval()命令以继续训练?

如果有人有一些代码片段/玩具示例,我将不胜感激!

2 个答案:

答案 0 :(得分:2)

  

如何在培训期间测试验证数据上的pytorch模型?

有很多例子说明训练期间每个时期都有训练和测试步骤。一个简单的就是official MNIST example。由于pytorch不提供任何高级培训,验证或评分框架,您必须自己编写。通常这包括

  • 数据加载器(通常基于torch.utils.dataloader.Dataloader
  • 一个超过纪元总数的主循环
  • 使用培训数据优化模型的train()函数
  • test()valid()函数,用于根据验证数据和指标衡量模型的有效性

这也是您在链接示例中可以找到的内容。

或者,您可以使用提供基本循环和验证工具的框架,这样您就不必一直自己实现所有内容。

  • tnt是火炬的火炬手,为您提供不同的指标(如准确性)和火车循环的抽象。见this MNIST example
  • infernotorchsample尝试对与Keras非常相似的内容进行建模并提供一些验证工具
  • skorch是pytorch的scikit-learn包装器,可让您使用sklearn中的所有工具和指标
  

另外,如何撤消myNet.eval()命令以继续培训?

myNet.train()或者,提供一个布尔值来在eval和training之间切换:myNet.train(True)用于列车模式。

答案 1 :(得分:1)

  

我知道有一个函数myNet.eval()可以切换任何丢失层,但是它是否也阻止了渐变的积累?

它不会阻止渐变累积。

但我认为在测试期间,你想要忽略渐变。在这种情况下,您应该将输入的变量标记为volatile=True,这样可以节省一些用于正向计算的时间和空间。

  

另外,如何撤消myNet.eval()命令以继续训练?

myNet.train()