Loss explosion during backpropagation

Time: 2021-01-29 12:05:50

Tags: pytorch

import torch
import torch.nn.functional as F
import numpy as np

def weighted_mse_loss(input, target, weights):
    # Per-pixel squared error, weighted by the (broadcast) weight map and
    # summed down to a single scalar.
    out = (input - target) ** 2
    # print('Simple-Out-Shape::', out.shape)
    out = out * weights.expand_as(out)
    out = out.squeeze(0)
    loss = out.sum(0)
    loss = loss.sum()  # or sum over whatever dimensions
    return loss
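
For clarity, here is a toy call of the function above showing how the weight map broadcasts over the squared error. The shapes are illustrative only, not my real data:

# Illustrative shapes: a batch of 2 single-channel 4x4 images and a
# 1x1x4x4 weight map, which expand_as broadcasts across batch and channels.
pred    = torch.rand(2, 1, 4, 4)
target  = torch.rand(2, 1, 4, 4)
weights = torch.rand(1, 1, 4, 4)   # stands in for the resized score map

loss = weighted_mse_loss(pred, target, weights)
print(loss)   # a single scalar tensor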

gen.train()
mse_loss = torch.nn.MSELoss()
for i, (hr, lr) in enumerate(train_dataloader):
    hr = hr.to(device)
    lr = lr.to(device)

    gen_optimizer.step()
    generated = gen(lr)

    # Target text regions: run CRAFT (+ refiner) on the HR image to get a
    # per-pixel score map, without tracking gradients.
    with torch.no_grad():
        y, feature = craft_model(hr)
        y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
        score_link = np.expand_dims(score_link, 0)
    score_link = torch.from_numpy(score_link)
    # Resize the score map to the HR resolution so it can weight the loss.
    score_link = F.interpolate(score_link.unsqueeze(0), size=hr[0][0].shape)
    region_map_loss = weighted_mse_loss(generated, hr, score_link)

    simple_mse_loss = mse_loss(generated, hr)
    print('Simple-MSE-Loss::', simple_mse_loss)
    print('RegionMapAware Loss::', region_map_loss)
    simple_mse_loss.backward(retain_graph=True)
    gen_optimizer.step()
    if i == 10:
        break

I want to concentrate the loss on the text regions, so I use a text-detection model (CRAFT) and try to mask the loss with the regions that model predicts. But instead of decreasing, my loss grows exponentially.
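
For comparison of scale only, a normalized variant of the same masking idea (a hypothetical helper, not the code I am actually running above) would divide by the total weight mass so the value stays on the same order as a plain element-wise MSE:

import torch

def normalized_weighted_mse(input, target, weights, eps=1e-8):
    # Hypothetical variant: same per-pixel masking as weighted_mse_loss,
    # but averaged over the weight mass instead of summed, so its magnitude
    # is comparable to nn.MSELoss (which averages over all elements).
    w = weights.expand_as(input)
    return ((input - target) ** 2 * w).sum() / (w.sum() + eps)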

Here is the progress output:

Simple-MSE-Loss:: tensor(1.7920, grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(1162.5378, grad_fn=<SumBackward0>)
Simple-MSE-Loss:: tensor(30369398., grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(1.6868e+10, grad_fn=<SumBackward0>)
Simple-MSE-Loss:: tensor(10484.7695, grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(6352537., grad_fn=<SumBackward0>)
Simple-MSE-Loss:: tensor(1.6715e+30, grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(1.0599e+33, grad_fn=<SumBackward0>)
--------------------------------------------------------------------

Thanks in advance for your help.

0 Answers:

No answers yet.