def weighted_mse_loss(input,target,weights):
out = (input-target)**2
#print('Simple-Ou-Shape::',out.shape,out)
out = out * weights.expand_as(out)
out=out.squeeze(0)
loss = out.sum(0)
loss=loss.sum() # or sum over whatever dimensions
return loss
gen.train()
mse_loss=torch.nn.MSELoss()
for i,(hr,lr) in enumerate(train_dataloader):
hr=hr.to(device)
lr=lr.to(device)
gen_optimizer.step()
generated=gen(lr)
#Target-Text-Regions
with torch.no_grad():
y,feature=craft_model(hr)
y_refiner = refine_net(y, feature)
score_link = y_refiner[0,:,:,0].cpu().data.numpy()
score_link=np.expand_dims(score_link,0)
score_link=torch.from_numpy(score_link)
score_link=F.interpolate(score_link.unsqueeze(0),size=(hr[0][0].shape))
region_map_loss=weighted_mse_loss(generated,hr,score_link)
simple_mse_loss=mse_loss(generated,hr)
print('Simple-MSE-Loss::',simple_mse_loss)
print('RegionMapAware Loss::',region_map_loss)
simple_mse_loss.backward(retain_graph=True)
gen_optimizer.step()
if i==10:
break
我想将损失集中在文本部分,所以我使用了文本检测器模型(CRAFT),并试图掩盖该模型预测区域的损失。 但我的损失并没有减少而是呈指数增长
这是进度输出
Simple-MSE-Loss:: tensor(1.7920, grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(1162.5378, grad_fn=<SumBackward0>)
Simple-MSE-Loss:: tensor(30369398., grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(1.6868e+10, grad_fn=<SumBackward0>)
Simple-MSE-Loss:: tensor(10484.7695, grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(6352537., grad_fn=<SumBackward0>)
Simple-MSE-Loss:: tensor(1.6715e+30, grad_fn=<MseLossBackward>)
RegionMapAware Loss:: tensor(1.0599e+33, grad_fn=<SumBackward0>)
--------------------------------------------------------------------
在此先感谢您的帮助。