Detectron2 - How to log validation loss during training?

Time: 2020-08-26 23:51:52

Tags: validation pytorch loss

I copied this idea from mnslarcher and wrote the following two functions for my keypoint detector (ResNet-50 backbone).

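import torch
import detectron2.utils.comm as comm
from detectron2.data import build_detection_train_loader

def build_valid_loader(cfg):
    # Reuse the training-loader pipeline, but read from the test datasets.
    _cfg = cfg.clone()
    _cfg.defrost()  # make this cfg mutable.
    _cfg.DATASETS.TRAIN = cfg.DATASETS.TEST
    return build_detection_train_loader(_cfg)

def store_valid_loss(model, data, storage):
    # The model is left in training mode: in eval mode a Detectron2 model
    # returns predictions instead of a loss dict.
    training_mode = model.training
    with torch.no_grad():  # no autograd graph for the validation forward pass
        loss_dict = model(data)
        losses = sum(loss_dict.values())
        assert torch.isfinite(losses).all(), loss_dict

        # Average each loss term across processes (valid on rank 0 only).
        loss_dict_reduced = {k: v.item()
                             for k, v in comm.reduce_dict(loss_dict).items()}
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        if comm.is_main_process():
            storage.put_scalars(val_loss=losses_reduced, **loss_dict_reduced)
    model.train(training_mode)  # restore the original mode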
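For readers unfamiliar with `comm.reduce_dict`: with 4 GPUs, each process computes its own loss dict, and `reduce_dict` averages each term across processes (by default) and delivers the result to the main process only, which is why `put_scalars` is guarded by `comm.is_main_process()`. The pure-Python sketch below mimics that averaging and the subsequent sum into one validation loss; the helper `reduce_loss_dicts` and the loss values are hypothetical, and no `torch.distributed` is involved.

```python
from statistics import mean


def reduce_loss_dicts(per_worker_losses):
    """Average each loss term across workers, mimicking comm.reduce_dict(average=True)."""
    keys = per_worker_losses[0].keys()
    return {k: mean(d[k] for d in per_worker_losses) for k in keys}


# Hypothetical per-GPU loss values for one validation batch on 4 GPUs.
workers = [
    {"loss_cls": 0.5, "loss_keypoint": 4.0},
    {"loss_cls": 0.25, "loss_keypoint": 5.0},
    {"loss_cls": 0.25, "loss_keypoint": 6.0},
    {"loss_cls": 0.0, "loss_keypoint": 5.0},
]

reduced = reduce_loss_dicts(workers)  # per-term averages
val_loss = sum(reduced.values())      # what store_valid_loss logs as val_loss
```

Here `reduced` is `{"loss_cls": 0.25, "loss_keypoint": 5.0}` and `val_loss` is `5.25`.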

Then, in plain_train_net.py, I call them as below.

    val_data_loader = build_valid_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, val_data, iteration in zip(data_loader, val_data_loader, range(start_iter, max_iter)):
            iteration = iteration + 1
            ..
            ..
            # At the end of the for loop:
            # calculate and log the validation loss.
            store_valid_loss(model, val_data, storage)
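The loop above zips the validation loader into every training iteration, so a validation forward pass runs at every step. A common alternative, sketched here under the assumption that validating every N iterations is acceptable, is to pull a validation batch only when the iteration number is a multiple of a period (similar in spirit to how `cfg.TEST.EVAL_PERIOD` gates evaluation). `run_training`, `train_step`, `val_step` and `eval_period` are hypothetical names, and `itertools.count()` stands in for Detectron2's infinite data loaders.

```python
from itertools import count


def run_training(start_iter, max_iter, eval_period, train_step, val_step):
    """Train every iteration; validate only every `eval_period` iterations."""
    train_batches = count()  # stand-in for the infinite training loader
    val_batches = count()    # stand-in for the infinite validation loader
    for data, iteration in zip(train_batches, range(start_iter, max_iter)):
        iteration = iteration + 1
        train_step(data)
        # Fetch a validation batch only when validation actually runs.
        if eval_period > 0 and iteration % eval_period == 0:
            val_step(next(val_batches), iteration)
```

For example, `run_training(0, 10, 4, ...)` trains for 10 iterations and validates at iterations 4 and 8.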

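After 1k iterations, loss_keypoint keeps increasing, but total_loss is the same as when store_valid_loss is not called. What am I missing? Can anyone help me understand this? I am using 4 GeForce RTX 2080 Ti GPUs.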
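One possible explanation, offered as an assumption rather than a confirmed diagnosis: if the training loop logs its losses as `storage.put_scalars(total_loss=..., **loss_dict_reduced)`, as the stock plain_train_net.py does, then `store_valid_loss` writes the per-term validation losses under the same names (`loss_keypoint`, ...), but the summed validation loss under a new name (`val_loss`). `total_loss` would then be untouched, while `loss_keypoint` receives a second, validation value at every iteration, so its logged curve mixes training and validation. The sketch below illustrates this with a hypothetical `ToyStorage` class (not Detectron2's `EventStorage`) and a hypothetical `store_losses` helper, and shows how a `val_` key prefix keeps the two series apart.

```python
from collections import defaultdict


class ToyStorage:
    """Hypothetical stand-in for EventStorage: keeps a list of values per key."""

    def __init__(self):
        self.history = defaultdict(list)

    def put_scalars(self, **kwargs):
        for name, value in kwargs.items():
            self.history[name].append(value)


def store_losses(storage, loss_dict, total_name, prefix=""):
    """Log each loss term plus their sum, optionally prefixing the term names."""
    scalars = {prefix + k: v for k, v in loss_dict.items()}
    scalars[total_name] = sum(loss_dict.values())
    storage.put_scalars(**scalars)


train = {"loss_keypoint": 2.0, "loss_cls": 0.5}  # made-up training losses
val = {"loss_keypoint": 3.0, "loss_cls": 0.5}    # made-up validation losses

clashing = ToyStorage()
store_losses(clashing, train, "total_loss")
store_losses(clashing, val, "val_loss")  # same term names as training

separated = ToyStorage()
store_losses(separated, train, "total_loss")
store_losses(separated, val, "val_loss", prefix="val_")
```

In `clashing`, `loss_keypoint` holds `[2.0, 3.0]` for a single iteration while `total_loss` holds only `[2.5]`; in `separated`, the validation terms land under `val_loss_keypoint` and `val_loss_cls`.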

0 answers