如何使用pytorch-lightning将模型权重保存到mlflow跟踪服务器?

时间:2019-12-03 03:20:45

标签: pytorch

我想使用pytorch-lightning将模型权重保存到mlflow跟踪中。 pytorch-lightning支持logging。 但是,似乎不支持将模型权重另存为mlflow上的工件。

起初,我计划重写ModelCheckpoint类来实现此目的,但是由于复杂的Mixin操作,我发现这很困难。

有人知道实现此目标的简单方法吗?

2 个答案:

答案 0 :(得分:0)

答案 1 :(得分:0)

正如@xela所说,您可以使用mlflow记录器的experiment对象记录工件。

如果您想在训练期间频繁记录模型权重,可以扩展ModelCheckpoint

class MLFlowModelCheckpoint(ModelCheckpoint):
    def __init__(self, mlflow_logger, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mlflow_logger = mlflow_logger

    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        run_id = self.mlflow_logger.run_id
        self.mlflow_logger.experiment.log_artifact(run_id, self.best_model_path)

然后在您的训练代码中使用

mlflow_logger = MLFlowLogger()
checkpoint_callback = MLFlowModelCheckpoint(mlflow_logger)
trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, logger=mlflow_logger)