Calculating prediction contributions in XGBoost with a log-transformed response

Time: 2020-07-03 09:06:23

Tags: r xgboost

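I am using the xgboost library and trying to obtain prediction contributions, i.e. how the different predictors affect a single prediction, by passing predcontrib = TRUE to the predict function.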

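The problem is that I have to log-transform the response to avoid getting negative predictions. With this transformation I can still obtain contribution values, but I believe they are on the log scale. Most of the values are below 1, so taking the exponential of each value does not give the expected result. The figure below illustrates the problem with three panels: a model trained on the untransformed values, a model trained on the log response, and that same log-response model with its contributions transformed back using the exponential function. The bottom panel should look more like the top panel and have the same scale.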

plot
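
The scale mismatch described above follows from how the contributions add up. As a sketch, write \phi_i for the contribution of feature i and \phi_0 for the bias term (these symbols are introduced here for illustration and do not appear in the question). The xgboost documentation states that the contributions plus the bias sum to the raw margin prediction, so for a model trained on log(y + 1):

```latex
\log(\hat{y} + 1) = \phi_0 + \sum_{i=1}^{p} \phi_i
\quad\Longrightarrow\quad
\hat{y} = e^{\phi_0} \prod_{i=1}^{p} e^{\phi_i} - 1
```

Each e^{\phi_i} is therefore a multiplicative factor on \hat{y} + 1 rather than an additive share of \hat{y}, which is why exponentiating the contributions one by one cannot reproduce the top panel.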

Here is the code:

library(xgboost)
library(dplyr)   # %>%, select(), pull(), mutate(), slice()
library(mlbench) # Just for data

data("BostonHousing")

par(mfrow = c(3, 1))

model <- xgboost(
    data = xgb.DMatrix(BostonHousing %>%
                           select(-medv) %>%
                           data.matrix(),
                       label = BostonHousing %>% pull(medv)),
    nrounds = 10,
    max_depth = 10,
    eta = 0.05,
    gamma = 5,
    colsample_bytree = 0.5,
    min_child_weight = 1,
    subsample = 0.8,
    verbose = FALSE
)

contrib <- predict(model,
                   BostonHousing %>%
                       dplyr::slice(1) %>% 
                       select(-medv) %>%
                       data.matrix(),
                   predcontrib = TRUE)

# Normal values
contrib

barplot(contrib,
        main = "Normal values that the transformed values should look like")


BostonHousing_log_y <- BostonHousing %>% 
    mutate(medv = log(medv + 1))

model_log_y <- xgboost(
    data = xgb.DMatrix(BostonHousing_log_y %>%
                           select(-medv) %>%
                           data.matrix(),
                       label = BostonHousing_log_y %>% pull(medv)),
    nrounds = 10,
    max_depth = 10,
    eta = 0.05,
    gamma = 5,
    colsample_bytree = 0.5,
    min_child_weight = 1,
    subsample = 0.8,
    verbose = FALSE
)

contrib_log_y <- predict(model_log_y,
                         BostonHousing %>%
                             dplyr::slice(1) %>% 
                             select(-medv) %>%
                             data.matrix(),
                         predcontrib = TRUE)

# Log-transformed values
contrib_log_y

barplot(contrib_log_y,
        main = "Logarithmic values before exponential transformation")

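# Attempt to back-transform each contribution. This does not recover the
# original scale: the contributions are additive on the log scale and most
# are below 1. (The inverse of log(medv + 1) is exp(x) - 1, not exp(x) + 1.)
contrib_exp_y <- exp(contrib_log_y) - 1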

contrib_exp_y

barplot(contrib_exp_y,
        main = "Values that could not be transformed back exponentially")
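
The following is a minimal sketch, not a definitive answer, of what the log-scale contributions do and do not allow. It assumes the default squared-error objective, uses xgb.train in place of xgboost() with a reduced set of hyperparameters, and introduces the names X, y_log, x1, log_sum, pred_orig and factors purely for illustration. It checks that the contributions (including the bias, returned as the last column) sum to the log-scale prediction, back-transforms that sum as a whole, and shows that the exponentiated contributions act as multiplicative factors.

```r
library(xgboost)
library(mlbench)  # only for the BostonHousing data

data("BostonHousing")

# Feature matrix and log-transformed response, as in the question
X <- data.matrix(BostonHousing[, setdiff(names(BostonHousing), "medv")])
y_log <- log(BostonHousing$medv + 1)

set.seed(1)
model_log_y <- xgb.train(
    params = list(objective = "reg:squarederror", max_depth = 10, eta = 0.05),
    data = xgb.DMatrix(X, label = y_log),
    nrounds = 10
)

# First observation, kept as a one-row matrix
x1 <- X[1, , drop = FALSE]
contrib_log_y <- predict(model_log_y, x1, predcontrib = TRUE)

# Feature contributions plus bias add up to the prediction on the log scale
pred_log <- predict(model_log_y, x1)
log_sum <- sum(contrib_log_y)

# Back-transform the sum as a whole, not each term separately
pred_orig <- exp(log_sum) - 1

# On the original scale each contribution is a multiplicative factor
# applied to (prediction + 1), not an additive amount
factors <- exp(contrib_log_y)
```

Under these assumptions, prod(factors) - 1 equals pred_orig, so the exponentiated values can be read as ratios (a factor of 1.05 raises prediction + 1 by about 5%), but they cannot be placed on the same additive axis as the contributions of the model trained on the untransformed response.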

0 answers:

No answers yet.