在pytorch中的tf.Saver()中是否有类似于max_to_keep的东西?如何仅保存N个模型检查点?火炬中有内置功能吗?
我正在使用下面的功能来保存检查点。
import os
import torch
def save_checkpoint(model, save_path):
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
torch.save(model.cpu().state_dict(), save_path)
model.cuda()