How do I save a model in PyTorch?

Time: 2020-07-08 13:54:50

Tags: python-3.x pytorch

Suppose we have a model that inherits from torch.nn.Module and is validated like this:

def validate(logger, config, valid_loader, model, criterion, epoch, main_proc):
    meters = AverageMeterGroup()
    model.eval()

    with torch.no_grad():
        for step, (x, y) in enumerate(valid_loader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            logits, _ = model(x)
            loss = criterion(logits, y)
            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = utils.reduce_metrics(metrics, config.distributed)
            meters.update(metrics)

            if main_proc and (step % config.log_frequency == 0 or step + 1 == len(valid_loader)):
                logger.info("Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1, config.epochs, step + 1, len(valid_loader), meters)

    if main_proc:
        logger.info("Train: [%d/%d] Final Prec@1 %.4f Prec@5 %.4f", epoch + 1, config.epochs, meters.prec1.avg, meters.prec5.avg)
    return meters.prec1.avg, meters.prec5.avg

How can I change the validation procedure so that the validated model is saved to the file system and can later be run on other data?

1 answer:

Answer 0: (score: 1)

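Call torch.save once the validation loop has finished:

def validate(logger, config, valid_loader, model, criterion, epoch, main_proc):
    meters = AverageMeterGroup()
    model.eval()

    with torch.no_grad():
        for step, (x, y) in enumerate(valid_loader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            logits, _ = model(x)
            loss = criterion(logits, y)
            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = utils.reduce_metrics(metrics, config.distributed)
            meters.update(metrics)

            if main_proc and (step % config.log_frequency == 0 or step + 1 == len(valid_loader)):
                logger.info("Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1, config.epochs, step + 1, len(valid_loader), meters)

    if main_proc:
        # Save from the main process only, so distributed workers do not all write the same file.
        # torch.save(model, ...) pickles the whole module; loading it later requires the model's
        # class definition to be importable.
        torch.save(model, 'model' + str(epoch) + '.pt')
        logger.info("Train: [%d/%d] Final Prec@1 %.4f Prec@5 %.4f", epoch + 1, config.epochs, meters.prec1.avg, meters.prec5.avg)
    return meters.prec1.avg, meters.prec5.avg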
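
As an alternative to pickling the whole module with torch.save(model, ...), you can save only the model's state_dict and rebuild the model when you want to run it on other data. The sketch below is a minimal illustration under assumptions, not the asker's actual setup: TinyNet is a hypothetical stand-in for the real model (it returns a (logits, aux) tuple like the model in the question), and save_checkpoint and load_for_inference are made-up helper names.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model; returns (logits, aux) like the question's model."""

    def __init__(self, in_features=8, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.fc(x), None


def save_checkpoint(model, epoch, directory):
    """Save only the weights (state_dict) to <directory>/model<epoch>.pt and return the path."""
    # Unwrap DataParallel / DistributedDataParallel so the saved keys have no "module." prefix.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model = model.module
    path = os.path.join(directory, "model%d.pt" % epoch)
    torch.save(model.state_dict(), path)
    return path


def load_for_inference(path, device="cpu"):
    """Rebuild the (assumed) architecture, load the saved weights, and switch to eval mode."""
    model = TinyNet()
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    return model


if __name__ == "__main__":
    trained = TinyNet()
    with tempfile.TemporaryDirectory() as tmp:
        checkpoint = save_checkpoint(trained, epoch=0, directory=tmp)
        restored = load_for_inference(checkpoint)

    # Run the restored model on new data, without tracking gradients.
    new_data = torch.randn(4, 8)
    with torch.no_grad():
        logits, _ = restored(new_data)
    print(logits.shape)
```

In the validate function you would call save_checkpoint(model, epoch, some_directory) inside the if main_proc: branch. Compared with saving the whole module, the state_dict file does not depend on the pickled class path, but you do need the code that constructs the model when loading.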