Loss when retraining a model is higher than the loss at the end of the previous session

Posted: 2020-08-23 15:14:40

Tags: python machine-learning deep-learning pytorch

I am training a model for a denoising task.

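Every time I train it, the first loss of the new training session is much higher than the loss I ended the previous session with. For example, after the first training session: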

[image: loss at the end of the first training session]

Then, when I train for another 100 epochs:

[image: loss at the start of the next 100 epochs]

The loss is 20x higher.

Why does this happen?

Edit: Here is the plotted loss. You can clearly see that the loss jumps sharply after around epoch 2500:

[image: plot of the training loss over all sessions]

Here is the training loop:

import torch
import torch.nn.functional as F

# lin_model, sta_model, ref_model, util, batch_size, B, K, functions, sample_offset,
# noise_mean, noise_std and lin_loss_list / sta_loss_list / ref_loss_list are defined
# elsewhere in my code (not shown here).

num_epochs = 500

criterion = torch.nn.MSELoss()

sgd_params = {
    "lr": 0.30,
    "momentum": 0.9
}

# These optimizers (and their momentum buffers) are created anew each time this code runs.
lin_optimizer = torch.optim.SGD(lin_model.parameters(), **sgd_params)
sta_optimizer = torch.optim.SGD(sta_model.parameters(), **sgd_params)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), **sgd_params)

# Fixed validation set, generated once before the loop
_, val_clean, val_noisy = util.lincomb_generate_data(batch_size*10, B, K, functions, sample_offset=sample_offset, noise_type="gaussian", noise_mean=noise_mean, noise_std=noise_std)

print("STARTED TRAINING")

for epoch in range(num_epochs):
    print("Current epoch {}\r".format(epoch), end="")
    # generate data returns the x-axis used for plotting as well as the clean and noisy data
    _, t_clean, t_noisy = util.lincomb_generate_data(batch_size, B, K, functions, sample_offset=sample_offset, noise_type="gaussian", noise_mean=noise_mean, noise_std=noise_std)

    # ===================forward=====================
    lin_output = lin_model(t_noisy.float())
    sta_output = sta_model(t_noisy.float())
    ref_output = ref_model(t_noisy.float())

    lin_loss = criterion(lin_output.float(), t_clean.float())
    sta_loss = criterion(sta_output.float(), t_clean.float())
    ref_loss = criterion(ref_output.float(), t_clean.float())

    # Record the training loss (computed before this epoch's update) as a plain Python float
    lin_loss_list.append(lin_loss.item())
    sta_loss_list.append(sta_loss.item())
    ref_loss_list.append(ref_loss.item())
    # ===================backward====================
    lin_optimizer.zero_grad()
    sta_optimizer.zero_grad()
    ref_optimizer.zero_grad()

    lin_loss.backward()
    sta_loss.backward()
    ref_loss.backward()

    lin_optimizer.step()
    sta_optimizer.step()
    ref_optimizer.step()

    # ===================validation==================
    # no_grad: the validation forward pass does not need an autograd graph
    with torch.no_grad():
        val_lin_loss = F.mse_loss(lin_model(val_noisy.float()), val_clean.float())
        val_sta_loss = F.mse_loss(sta_model(val_noisy.float()), val_clean.float())
        val_ref_loss = F.mse_loss(ref_model(val_noisy.float()), val_clean.float())

print("DONE TRAINING")
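For comparison, here is a minimal sketch of resuming a session so it continues exactly where the previous one stopped, assuming that is the intent. PyTorch's SGD with momentum keeps a per-parameter buffer (with the default dampening of 0, v ← μ·v + g, then θ ← θ − lr·v), and that buffer lives in the optimizer, not the model, so the sketch saves and restores both `state_dict()`s. It also evaluates a fixed validation set right before saving and right after loading: if those two numbers differ, the model itself changed between sessions. The `torch.nn.Linear` stand-in model, the random data, the checkpoint file name and the helpers `save_session`, `load_session` and `val_loss` are all hypothetical placeholders for `lin_model` and the real data; this is not a confirmed explanation of the jump in the question.

```python
import os
import tempfile

import torch
import torch.nn.functional as F


def save_session(path, model, optimizer, epoch):
    # Hypothetical helper: store weights, optimizer state (incl. momentum buffers) and epoch
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch},
        path,
    )


def load_session(path, model, optimizer):
    # Hypothetical helper: restore what save_session stored into freshly created objects
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["epoch"]


def val_loss(model, noisy, clean):
    # MSE on a fixed validation set, without building an autograd graph
    was_training = model.training
    model.eval()
    with torch.no_grad():
        loss = F.mse_loss(model(noisy), clean).item()
    model.train(was_training)
    return loss


torch.manual_seed(0)
# Stand-in data and model (hypothetical; replace with the real denoiser and data)
train_clean = torch.randn(256, 16)
train_noisy = train_clean + 0.1 * torch.randn(256, 16)
val_clean = torch.randn(64, 16)
val_noisy = val_clean + 0.1 * torch.randn(64, 16)

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# ---- session 1: train a few epochs, then save ----
for epoch in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(train_noisy), train_clean)
    loss.backward()
    optimizer.step()

path = os.path.join(tempfile.gettempdir(), "denoiser_checkpoint.pt")
loss_before = val_loss(model, val_noisy, val_clean)
save_session(path, model, optimizer, epoch=5)

# ---- session 2: brand-new objects (as after a restart), then restore ----
model2 = torch.nn.Linear(16, 16)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.01, momentum=0.9)
start_epoch = load_session(path, model2, optimizer2)
loss_after = val_loss(model2, val_noisy, val_clean)

print("resuming at epoch", start_epoch)
print("val loss before save:", loss_before, "after load:", loss_after)
```

If the two validation losses match but the first training loss of the new session is still much higher, the difference has to come from the freshly generated training batch rather than from the weights.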

Edit 2: bump

0 answers
