修复要保留的模型检查点数量

时间:2019-01-14 10:56:35

标签: tensorflow pytorch

在pytorch中的tf.Saver()中是否有类似于max_to_keep的东西?如何仅保存N个模型检查点?火炬中有内置功能吗?

我正在使用下面的功能来保存检查点。

import os
import torch

def save_checkpoint(model, save_path):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    torch.save(model.cpu().state_dict(), save_path)
    model.cuda()

0 个答案:

没有答案