Python, XGBoost, custom objective function: predictions stay constant across iterations even with non-zero gradients

Asked: 2018-08-23 23:43:22

Tags: xgboost loss-function

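I am trying to use a custom loss function with xgboost. I will present results for four choices of objective:

  • 'reg:linear', defined by xgboost, which uses mse
  • mse_approx_obj, a custom objective using mse (code below)
  • mae_approx_obj, a custom objective using "mean_absolute_error" (code below)
  • pseudohuber_approx_obj, a custom objective using the pseudo-Huber loss (Wiki)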

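As the results below show, the algorithm works well for both the xgboost-defined mse and the custom mse. For the custom mae and for the smoothed first-order loss function (Huber), however, it behaves erratically. For example, with mae and Huber the predictions do not change even though the gradient values are large. This seems very strange to me. Any pointers?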

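import numpy as np

# With the scikit-learn wrapper the objective is called as
# objective(y_true, y_pred), so `dtrain` below is the label array.

def mse_approx_obj(dtrain, preds):
    d = preds - dtrain                    # residual
    grad_mse = d                          # first derivative of 0.5 * d**2
    hess_mse = np.full(d.shape, 1.0)      # second derivative is constant

    return grad_mse, hess_mse

def mae_approx_obj(dtrain, preds):
    d = preds - dtrain
    grad_mae = np.array(d)
    grad_mae[grad_mae > 0] = 1.           # sign of the residual
    grad_mae[grad_mae <= 0] = -1.
    hess_mae = np.full(d.shape, 0.0)      # second derivative of |d| is zero

    return grad_mae, hess_mae

def pseudohuber_approx_obj(dtrain, preds):
    d = preds - dtrain
    h = 1  # h is the delta
    scale = 1 + (d / h) ** 2
    scale_sqrt = np.sqrt(scale)
    grad = d / scale_sqrt                 # d / sqrt(1 + (d/h)^2)
    hess = 1 / scale / scale_sqrt         # (1 + (d/h)^2) ** -1.5

    return grad, hess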
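
To rule out a mistake in the derivatives themselves, a finite-difference check can be run against the pseudo-Huber formulas. The sketch below is only illustrative: pseudohuber_loss, pseudohuber_grad_hess and finite_difference_check are helper names introduced here, and the residual grid is made up rather than taken from my data.

```python
import numpy as np

def pseudohuber_loss(d, h=1.0):
    """Pseudo-Huber loss for residual d = preds - labels with scale h."""
    return h ** 2 * (np.sqrt(1.0 + (d / h) ** 2) - 1.0)

def pseudohuber_grad_hess(d, h=1.0):
    """Analytic first and second derivatives, as in pseudohuber_approx_obj."""
    scale = 1.0 + (d / h) ** 2
    scale_sqrt = np.sqrt(scale)
    return d / scale_sqrt, 1.0 / scale / scale_sqrt

def finite_difference_check(d, h=1.0, eps=1e-4):
    """Largest absolute gap between analytic and central-difference derivatives."""
    grad, hess = pseudohuber_grad_hess(d, h)
    loss_plus = pseudohuber_loss(d + eps, h)
    loss_mid = pseudohuber_loss(d, h)
    loss_minus = pseudohuber_loss(d - eps, h)
    num_grad = (loss_plus - loss_minus) / (2 * eps)
    num_hess = (loss_plus - 2 * loss_mid + loss_minus) / eps ** 2
    return np.max(np.abs(grad - num_grad)), np.max(np.abs(hess - num_hess))

# Hypothetical residual grid, chosen only for the check.
residuals = np.linspace(-5.0, 5.0, 101)
grad_gap, hess_gap = finite_difference_check(residuals)
print(grad_gap, hess_gap)
```

If both gaps are tiny, the formulas in pseudohuber_approx_obj are at least consistent with the loss, and the problem would have to lie elsewhere.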

from xgboost import XGBRegressor

# Xfeatures, target, init_params, training_data and validation_data are not shown here.
def model(init_params, training_data, validation_data):

    xgb1 = XGBRegressor(learning_rate=0.1, n_estimators=5, max_depth=8,
                        min_child_weight=1, gamma=0, subsample=0.8,
                        colsample_bytree=0.8, objective=mae_approx_obj, seed=2)

    x_train, y_train = training_data[Xfeatures], training_data[target]
    x_test, y_test = validation_data[Xfeatures], validation_data[target]

    xgb1.fit(x_train, y_train, eval_set=[(x_train, y_train), (x_test, y_test)],
             eval_metric='rmse', verbose=True)

    return xgb1

xgb1 = model(init_params, training_data, validation_data)

The results for the four objectives, in the order listed above, are:

[0] validation_0-rmse:61.5518   validation_1-rmse:57.0926
[1] validation_0-rmse:55.4669   validation_1-rmse:51.2765
[2] validation_0-rmse:49.9936   validation_1-rmse:46.3276
[3] validation_0-rmse:45.0812   validation_1-rmse:41.9738
[4] validation_0-rmse:40.6609   validation_1-rmse:37.9743

[0] validation_0-rmse:61.5506   validation_1-rmse:57.0223
[1] validation_0-rmse:55.4673   validation_1-rmse:51.5381
[2] validation_0-rmse:49.9943   validation_1-rmse:46.6359
[3] validation_0-rmse:45.0801   validation_1-rmse:42.2759
[4] validation_0-rmse:40.6617   validation_1-rmse:38.3059

For mae_approx_obj, I print preds and grad_mae on every call to mae_approx_obj to help with debugging.

[ 0.5  0.5  0.5 ...,  0.5  0.5  0.5]
[-1. -1. -1. ..., -1. -1. -1.]
[0] validation_0-rmse:68.3176   validation_1-rmse:63.0391
[ 0.5  0.5  0.5 ...,  0.5  0.5  0.5]
[-1. -1. -1. ..., -1. -1. -1.]
[1] validation_0-rmse:68.3176   validation_1-rmse:63.0391
[ 0.5  0.5  0.5 ...,  0.5  0.5  0.5]
[-1. -1. -1. ..., -1. -1. -1.]
[2] validation_0-rmse:68.3176   validation_1-rmse:63.0391
[ 0.5  0.5  0.5 ...,  0.5  0.5  0.5]
[-1. -1. -1. ..., -1. -1. -1.]
[3] validation_0-rmse:68.3176   validation_1-rmse:63.0391
[ 0.5  0.5  0.5 ...,  0.5  0.5  0.5]
[-1. -1. -1. ..., -1. -1. -1.]
[4] validation_0-rmse:68.3176   validation_1-rmse:63.0391
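
The debug printing used for these logs can be sketched as a small wrapper around the objective. This is a minimal illustration rather than my exact code: with_debug_print is a hypothetical helper, the argument names y_true and y_pred follow the scikit-learn wrapper's calling convention, and the three labels are made-up values.

```python
import numpy as np

def with_debug_print(objective):
    """Wrap an objective so each call prints the predictions and the gradient."""
    def wrapped(y_true, y_pred):
        grad, hess = objective(y_true, y_pred)
        print(y_pred)
        print(grad)
        return grad, hess
    return wrapped

def mae_approx_obj(y_true, y_pred):
    """Same derivatives as the mae objective above: sign gradient, zero Hessian."""
    d = y_pred - y_true
    grad = np.where(d > 0, 1.0, -1.0)
    hess = np.zeros_like(d)
    return grad, hess

# Hypothetical labels; 0.5 is the constant prediction seen in the log above.
labels = np.array([60.0, 71.0, 69.0])
preds = np.full(labels.shape, 0.5)

debug_obj = with_debug_print(mae_approx_obj)
grad, hess = debug_obj(labels, preds)
```

The wrapped function can be passed as objective in place of the plain one, so the printing stays separate from the loss definition.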

For pseudohuber_approx_obj:

[ 60.79294968  71.14537811  68.94273376 ...,  70.04405212  70.04405212 68.72246552]
[-1.99890065 -1.99919903 -1.9991467  ..., -1.9991734  -1.999173 -1.9991411]
[0] validation_0-rmse:1138.79   validation_1-rmse:1106.67

[ 60.79294968  71.14537811  68.94273376 ..., 70.04405212  70.04405212 68.72246552]
[ 1.99999702  1.99999702  1.99999702 ...,  1.99999702  1.9999970 1.99999702]
[1] validation_0-rmse:149.637   validation_1-rmse:261.029

[60.79294968  71.14537811  68.94273376 ...,  70.04405212  70.04405212 68.72246552]
[-1.9996798  -1.9997319  -1.99972188 ..., -1.99972689 -1.99972689 -1.99972081]
[2] validation_0-rmse:149.637   validation_1-rmse:261.029

[ 60.79294968  71.14537811  68.94273376 ...,  70.04405212  70.04405212 68.72246552]
[-1.9996798  -1.9997319  -1.99972188 ..., -1.99972689 -1.99972689 -1.99972081]
[3] validation_0-rmse:149.637   validation_1-rmse:261.029

[60.79294968  71.14537811  68.94273376 ...,  70.04405212  70.04405212 68.72246552]
[-1.9996798 -1.9997319 -1.99972188 ..., -1.99972689 -1.99972689 -1.999720]
[4] validation_0-rmse:149.637   validation_1-rmse:261.029
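
One diagnostic that may be relevant: the xgboost parameter documentation describes min_child_weight as the minimum sum of instance weight (hessian) needed in a child, so it may be worth comparing the summed Hessian each custom objective produces against the min_child_weight=1 used in the model above. The sketch below does this under assumptions: hessian_sums is a helper introduced here, and n_samples and the residual range are made up to resemble the magnitudes in the logs, not my actual data.

```python
import numpy as np

def hessian_sums(residuals, h=1.0):
    """Sum of per-sample Hessians for the three custom objectives."""
    scale = 1.0 + (residuals / h) ** 2
    return {
        "mse": float(np.sum(np.ones_like(residuals))),
        "mae": float(np.sum(np.zeros_like(residuals))),
        "pseudohuber": float(np.sum(1.0 / scale / np.sqrt(scale))),
    }

# Hypothetical data: residuals (preds - labels) between -72 and -60.
n_samples = 1000
rng = np.random.default_rng(0)
residuals = rng.uniform(-72.0, -60.0, size=n_samples)

min_child_weight = 1.0
sums = hessian_sums(residuals)
for name, total in sums.items():
    # True means the summed Hessian reaches the min_child_weight threshold.
    print(name, total, total >= min_child_weight)
```

With these assumed values the mae sum is exactly zero and the pseudo-Huber sum is far below 1, while the mse sum equals the number of samples. Whether this fully accounts for the unchanged predictions is not something the sketch can confirm.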

PS: There is a very similar question here, but it has no answers. It has been a while, so I am asking again in case I am missing something.

0 answers:

No answers