我想修改以下keras均方误差丢失(MSE),以便仅以稀疏方式计算损失。
def mean_squared_error(y_true, y_pred):
return K.mean(K.square(y_pred - y_true), axis=-1)
我的输出y
是一个3通道图像,其中第3个通道仅在那些要计算损耗的像素处为非零。知道如何修改以上内容来计算稀疏损失?
答案 0 :(得分:5)
这不是您正在寻找的确切损失,但我希望它能为您提供编写功能的提示:
def masked_mse(mask_value):
def f(y_true, y_pred):
mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
masked_squared_error = K.square(mask_true * (y_true - y_pred))
masked_mse = (K.sum(masked_squared_error, axis=-1) /
K.sum(mask_true, axis=-1))
return masked_mse
f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
return f
该函数计算预测输出的所有值的MSE损失,除了那些在真输出中的对应值等于掩蔽值(例如-1)的元素。
两个说明:
- 计算均值时,分母必须是非掩盖值的计数,而不是
数组的维度,这就是为什么我没有使用K.mean(masked_squared_error, axis=1)
而我是' m
而是手动平均 。
- 屏蔽值必须是有效数字(即np.nan
或np.inf
将无法完成工作),这意味着您必须调整数据以使其不包含{ {1}}。
在此示例中,目标输出始终为mask_value
,但某些预测值会逐渐被屏蔽。
[1, 1, 1, 1]
预期输出为:
y_pred = K.constant([[ 1, 1, 1, 1],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
[ 1, 1, 1, 1],
[-1, 1, 1, 1],
[-1,-1, 1, 1],
[-1,-1,-1, 1],
[-1,-1,-1,-1]])
true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))
for i in range(true.shape[0]):
print(true[i], pred[i], loss[i], sep='\t')
答案 1 :(得分:0)
为防止出现nan
,请遵循指示here。以下假设您希望掩码值(背景)等于零:
# Copied almost character-by-character (only change is default mask_value=0)
# from https://github.com/keras-team/keras/issues/7065#issuecomment-394401137
def masked_mse(mask_value=0):
"""
Made default mask_value=0; not sure this is necessary/helpful
"""
def f(y_true, y_pred):
mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
masked_squared_error = K.square(mask_true * (y_true - y_pred))
# in case mask_true is 0 everywhere, the error would be nan, therefore divide by at least 1
# this doesn't change anything as where sum(mask_true)==0, sum(masked_squared_error)==0 as well
masked_mse = K.sum(masked_squared_error, axis=-1) / K.maximum(K.sum(mask_true, axis=-1), 1)
return masked_mse
f.__name__ = str('Masked MSE (mask_value={})'.format(mask_value))
return f