如何使用AllenNLP设置完全禁用模型/权重序列化?

时间:2020-10-12 09:44:54

标签: allennlp

我希望通过使用jsonnet配置文件来禁用序列化标准AllenNLP模型训练中的所有模型/状态权重。

其原因是我正在使用Optuna运行自动超参数优化。 测试数十种模型会很快填充驱动器。 我已经通过将num_serialized_models_to_keep设置为0来禁用检查指针:

trainer +: {
    checkpointer +: {
        num_serialized_models_to_keep: 0,
    },

我不希望将serialization_dir设置为None,因为我仍然想要有关记录中间指标等的默认行为。我只想禁用默认模型状态,训练状态,并撰写最佳模型权重

除了我在上面设置的选项之外,是否还有任何默认的Trainer或checkpointer选项来禁用所有模型权重序列化?我检查了API文档和网页,但找不到任何内容。

如果我需要自己定义此类选项的功能,我应该在我的Model子类中覆盖AllenNLP的哪些基本功能?

或者,当训练结束时,它们是否可用于清除中间模型状态?

编辑: @petew's answer显示了自定义检查指针的解决方案,但是我不清楚如何在我的用例中使该代码可被allennlp train找到。< / p>

我希望可以从如下配置文件中调用custom_checkpointer:

trainer +: {
    checkpointer +: {
        type: empty,
    },

调用allennlp train --include-package <$my_package>时加载检查点的最佳实践是什么?

我有my_package,其子目录位于my_package/modelsmy_package/training等子目录中。 我想将自定义检查点代码放在my_package/training/custom_checkpointer.py中 我的主要模型位于my_package/models/main_model.py中。 我必须在main_model类中编辑或导入任何代码/函数才能使用自定义检查点吗?

1 个答案:

答案 0 :(得分:1)

您可以创建并注册一个自定义Checkpointer,而该自定义基本上什么都不做:

@Checkpointer.register("empty")
class EmptyCheckpointer(Registrable):
    def maybe_save_checkpoint(
        self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int
    ) -> None:
        pass

    def save_checkpoint(
        self,
        epoch: Union[int, str],
        trainer: "allennlp.training.trainer.Trainer",
        is_best_so_far: bool = False,
        save_model_only=False,
    ) -> None:
        pass

    def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]:
        pass

    def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        return {}, {}

    def best_model_state(self) -> Dict[str, Any]:
        return {}