我正在努力使CNN回归网络平稳运行,但输出结果与实际情况相去甚远。输入是二进制图像(600x600)
,背景为黑色,前景为白色。与每个输入关联的地面真相是图像,其颜色范围为0 to 255
,并且在0 and 1.
在下面,您可以分别看到Input (x)
,ground truth(y)
和predicted output (y_hat)
的一个样本。该网络几乎可以检测到边缘和背景,但是在前景中,所有的预测值几乎都相同(请参见预测之一)。
我相信,由于输入中的所有前景都是1
,因此网络无法正确更新权重,因此可能需要自定义损失函数来惩罚损失函数。
我想知道我的想法是否正确,如果需要,我需要哪种自定义功能?
我试图通过下面的自定义损失函数对前景进行惩罚,但是并没有改善结果。
mse = nn.MSELoss(reduction=‘mean’)
def criterion(y, y_hat, x, loss, weight=0.1):
y_hat_modified = torch.where(x[:,0:1]==1, weight*y_hat,y_hat)
return loss(y,y_hat_modified)