我的样本有组ID。有没有办法将组ID传递给我的keras模型并根据组ID计算损失?
答案 0 :(得分:0)
您可以将组标签作为虚拟列插入目标向量 y 中,并将其从损失中排除。以下代码计算每组的均方误差并返回最大值。
def worst_case_group_mse(y_true, y_pred):
"""calculate mean squared error for each group separately and return worst value
Args:
y_true, y_pred (tf.Tensor):
last column corresponds to group index,
mean squared error calculated over all other columns
Returns:
tf.Tensor: maximum grouped mean squared error
"""
groups = tf.cast(y_true[:,-1], tf.int32)
y_true, y_pred = y_true[:,:-1], y_pred[:,:-1]
square = tf.math.square(y_pred - y_true)
unique, idx, count = tf.unique_with_counts(groups)
group_losses = tf.math.unsorted_segment_mean(square, idx, tf.size(unique))
group_losses = tf.math.reduce_mean(group_losses, axis=1)
return tf.math.reduce_max(group_losses)