我正在训练网络以使用Keras执行语义分割。通常,数据集包含void / unlabeled类。在论文中,这门课总是被忽略。这意味着网络不会将任何像素预测为无效,并且在计算指标时会忽略实际无效的像素。
简而言之,我希望网络不能预测给定的类。在混淆矩阵中,这意味着给定类的一行0:
[[ 0 0 0]
[ 553 109791 310]
[ 121 1756 264292]]
由于class_weight
中的fit_generator
参数不支持3维数据,而我的输入数据是4D(批量,高度,宽度,类),我现在正在尝试定制计算加权分类交叉熵的损失函数。实施:
def weighted_categorical_crossentropy(class_weights):
tf_weights = tf.convert_to_tensor(class_weights, np.float32)
def run(y_true, y_pred):
# scale preds so that the class probas of each sample sum to 1
y_pred /= tf.reduce_sum(y_pred, -1, True)
# manual computation of crossentropy
_epsilon = tf.convert_to_tensor(1e-7, y_pred.dtype.base_dtype)
output = tf.clip_by_value(y_pred, _epsilon, 1. - _epsilon)
return - tf.reduce_sum(tf.multiply(y_true * tf.log(output), tf_weights), -1)
return run
由于我的模型的最后一层是Softmax图层,因此实现与tensorflow backend和TensorFlow: Implementing a class-wise weighted cross entropy loss?
中的categorical_crossentropy
非常相似。
用法:model.compile(optimizer=optim, loss=weighted_categorical_crossentropy(class_weights), metrics='accuracy')
。
根据我的测试,属于具有类权重0的void类的像素总是如预期的那样丢失0。问题是网络仍然将像素预测为无效。
我做错了什么?