Attach Shapley reason codes for all observations to the whole dataset

Posted: 2018-11-20 02:16:14

Tags: r tidyverse purrr

I have code that gets the top 5 Shapley reason codes on the mtcars dataset, but only for a single observation.

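  1. How do I get the reason codes for all observations, instead of only once for mtcars[1,] as in the second-to-last line of my code?
  2. Then, how do I append (left_join) shapleyresults to the whole dataset using id?

    The dataset will become 5 times longer. Should we use purrr for this?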
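The pattern behind both questions (score every row, keep the top 5, stack the per-row results, then left-join on id) can be sketched in base R without iml. This is only an illustration under stated assumptions: `toy_reasons()` is a hypothetical stand-in for `Shapley$new(predictor, x.interest = ...)$results`, and its "phi" is just each feature's standardised value, not a Shapley value. `lapply()` plus `do.call(rbind, ...)` plays the role of `purrr::map_dfr()`.

```r
# Toy data: mtcars plus a row id to join on
obs <- mtcars
obs$id <- seq_len(nrow(obs))
features <- setdiff(names(obs), "id")

# Standardised feature values: a placeholder "contribution", NOT a Shapley value
z <- scale(obs[features])

# Hypothetical stand-in for Shapley$new(...)$results: the k features with the
# largest placeholder phi for row i, tagged with that row's id
toy_reasons <- function(i, k = 5) {
  res <- data.frame(feature = features, phi = unname(z[i, ]),
                    stringsAsFactors = FALSE)
  res <- res[order(res$phi, decreasing = TRUE), ][seq_len(k), ]
  res$id <- i
  res
}

# One result per row, stacked into a single long table (like purrr::map_dfr)
reasons <- do.call(rbind, lapply(seq_len(nrow(obs)), toy_reasons))

# Left join back on id: every observation now appears k = 5 times
final <- merge(obs, reasons, by = "id", all.x = TRUE)
nrow(final)  # 160 = 32 observations x 5 reason codes
```

With purrr the stacking step would be `map_dfr(seq_len(nrow(obs)), toy_reasons)`; either way the joined table is k times as long as the original, which is the 5x growth the question anticipates.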

1 answer:

Answer 0 (score: 0)

I found a solution.

# install.packages(c("randomForest", "tidyverse", "iml"))
library(tidyverse)
library(iml)
library(randomForest)

set.seed(42)

# Make the target a factor and add a row id to join on later
mtcars1 <- mtcars %>%
  mutate(vs = as.factor(vs),
         id = row_number())

# Build "vs ~ mpg + cyl + ..." from every column except the target and the id
x <- "vs"
y <- paste0(setdiff(names(mtcars1), c("vs", "id")), collapse = " + ")

rf <- randomForest(as.formula(paste0(x, " ~ ", y)), data = mtcars1, ntree = 50)

predictor <- Predictor$new(rf, data = mtcars1, y = mtcars1$vs)

# For every row: compute Shapley values, keep the 5 rows with the largest phi,
# and tag them with that row's id; map_dfr() stacks the per-row tibbles
shapleyresults <- map_dfr(seq_len(nrow(mtcars1)), ~ Shapley$new(predictor, x.interest = mtcars1[.x, ])$results %>%
                            as_tibble() %>%
                            arrange(desc(phi)) %>%
                            slice(1:5) %>%
                            select(feature.value, phi) %>%
                            mutate(id = .x))

# Left join on id: every observation now appears 5 times, once per reason code
final_data <- mtcars1 %>% left_join(shapleyresults, by = "id")
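If the 5x longer table is unwanted, one possible alternative (not part of the original answer) is to spread the top 5 reason codes into columns before joining, so the result keeps one row per observation. The base-R sketch below assumes a long table with the same id / feature.value / phi layout as shapleyresults; because it runs without iml, `long_reasons` is a hypothetical placeholder whose phi is just a standardised column value, not a Shapley value.

```r
# Hypothetical placeholder for shapleyresults: for each row of mtcars, the 5
# columns with the largest standardised value, in long id / feature.value / phi form
z <- scale(mtcars)
long_reasons <- do.call(rbind, lapply(seq_len(nrow(mtcars)), function(i) {
  top <- order(z[i, ], decreasing = TRUE)[1:5]
  data.frame(id = i,
             feature.value = paste0(colnames(mtcars)[top], "=", unlist(mtcars[i, top])),
             phi = unname(z[i, top]),
             stringsAsFactors = FALSE)
}))

# Number the reasons 1..5 within each id (rows are already sorted by phi)
long_reasons$rank <- ave(long_reasons$phi, long_reasons$id, FUN = seq_along)

# Long -> wide: one row per id with feature.value.1, phi.1, ..., feature.value.5, phi.5
wide <- reshape(long_reasons, idvar = "id", timevar = "rank", direction = "wide")

# Join back: the result keeps one row per observation
mtcars_id <- mtcars
mtcars_id$id <- seq_len(nrow(mtcars_id))
final_wide <- merge(mtcars_id, wide, by = "id")
nrow(final_wide)  # 32
```

This is the usual long-vs-wide trade-off: the long table from the answer is easier to filter or plot per reason code, while the wide one keeps the original row count. In a tidyverse pipeline, tidyr::pivot_wider() would do the same reshaping.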