我想使用pytorch-lightning将模型权重保存到mlflow跟踪中。 pytorch-lightning支持logging。 但是,似乎不支持将模型权重另存为mlflow上的工件。
起初,我计划重写ModelCheckpoint类来实现此目的,但是由于复杂的Mixin操作,我发现这很困难。
有人知道实现此目标的简单方法吗?
答案 0 :(得分:0)
答案 1 :(得分:0)
正如@xela所说,您可以使用mlflow记录器的experiment
对象记录工件。
如果您想在训练期间频繁记录模型权重,可以扩展ModelCheckpoint:
class MLFlowModelCheckpoint(ModelCheckpoint):
def __init__(self, mlflow_logger, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mlflow_logger = mlflow_logger
@rank_zero_only
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
run_id = self.mlflow_logger.run_id
self.mlflow_logger.experiment.log_artifact(run_id, self.best_model_path)
然后在您的训练代码中使用
mlflow_logger = MLFlowLogger()
checkpoint_callback = MLFlowModelCheckpoint(mlflow_logger)
trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, logger=mlflow_logger)