Keras:计算损失为*中位数*跨数据点而不是平均值

时间:2017-05-31 01:07:06

标签: optimization deep-learning keras

Keras losses页面说如果我们有自定义丢失函数,那么"实际优化的目标是所有数据点上输出数组的平均值。"有没有什么办法可以在所有数据点(而不是平均值)上优化输出数组的中值

1 个答案:

答案 0 :(得分:2)

为此,您需要降低到张量流水平

import keras
import tensorflow


def pick_median(arg_tensor):
    the_upper_tensor = tensorflow.contrib.distributions.percentile(arg_tensor, 50, interpolation='higher')
    the_lower_tensor = tensorflow.contrib.distributions.percentile(arg_tensor, 50, interpolation='lower')

    final_tensor = (the_upper_tensor + the_lower_tensor) / 2
    # print(the_count.eval(session=keras.backend.get_session()))

    return final_tensor

这就是您定义median_squared_error损失函数的方式:

def median_squared_error(arg_y_true,
                         arg_y_pred):
    final_tensor = keras.backend.square(arg_y_pred - arg_y_true)
    final_tensor = pick_median(arg_tensor=final_tensor)
    return final_tensor