我希望通过使用jsonnet
配置文件来禁用序列化标准AllenNLP模型训练中的所有模型/状态权重。
其原因是我正在使用Optuna运行自动超参数优化。
测试数十种模型会很快填充驱动器。
我已经通过将num_serialized_models_to_keep
设置为0
来禁用检查指针:
trainer +: {
checkpointer +: {
num_serialized_models_to_keep: 0,
},
我不希望将serialization_dir
设置为None
,因为我仍然想要有关记录中间指标等的默认行为。我只想禁用默认模型状态,训练状态,并撰写最佳模型权重。
除了我在上面设置的选项之外,是否还有任何默认的Trainer或checkpointer选项来禁用所有模型权重序列化?我检查了API文档和网页,但找不到任何内容。
如果我需要自己定义此类选项的功能,我应该在我的Model子类中覆盖AllenNLP的哪些基本功能?
或者,当训练结束时,它们是否可用于清除中间模型状态?
编辑: @petew's answer显示了自定义检查指针的解决方案,但是我不清楚如何在我的用例中使该代码可被allennlp train
找到。< / p>
我希望可以从如下配置文件中调用custom_checkpointer:
trainer +: {
checkpointer +: {
type: empty,
},
调用allennlp train --include-package <$my_package>
时加载检查点的最佳实践是什么?
我有my_package,其子目录位于my_package/models
和my_package/training
等子目录中。
我想将自定义检查点代码放在my_package/training/custom_checkpointer.py
中
我的主要模型位于my_package/models/main_model.py
中。
我必须在main_model类中编辑或导入任何代码/函数才能使用自定义检查点吗?
答案 0 :(得分:1)
您可以创建并注册一个自定义Checkpointer,而该自定义基本上什么都不做:
@Checkpointer.register("empty")
class EmptyCheckpointer(Registrable):
def maybe_save_checkpoint(
self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int
) -> None:
pass
def save_checkpoint(
self,
epoch: Union[int, str],
trainer: "allennlp.training.trainer.Trainer",
is_best_so_far: bool = False,
save_model_only=False,
) -> None:
pass
def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]:
pass
def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
return {}, {}
def best_model_state(self) -> Dict[str, Any]:
return {}