从python中的xgboost提取决策规则

时间:2019-02-01 11:04:21

标签: python sas xgboost

我想为即将到来的模型在python中使用xgboost。但是,由于我们的生产系统位于SAS中,因此我试图从xgboost中提取决策规则,然后编写SAS评分代码以在SAS环境中实现此模型。

我已经通过多个链接到本不见了。下面是其中一些:

How to extract decision rules (features splits) from xgboost model in python3?

xgboost deployment

以上两个链接特别对Shiutang-Li给出的用于xgboost部署的代码有很大帮助。但是,我的预测分数并不完全匹配。

以下是到目前为止我尝试过的代码:

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.grid_search import GridSearchCV
%matplotlib inline
import graphviz
from graphviz import Digraph

#Read the sample iris data:
iris =pd.read_csv("C:\\Users\\XXXX\\Downloads\\Iris.csv")
#Create dependent variable:
iris.loc[iris["class"] != 2,"class"] = 0
iris.loc[iris["class"] == 2,"class"] = 1

#Select independent and dependent variable:
X = iris[["sepal_length","sepal_width","petal_length","petal_width"]]
Y = iris["class"]

xgdmat = xgb.DMatrix(X, Y) # Create our DMatrix to make XGBoost more efficient

#Build the sample xgboost Model:

our_params = {'eta': 0.1, 'seed':0, 'subsample': 0.8, 'colsample_bytree': 0.8, 
             'objective': 'binary:logistic', 'max_depth':3, 'min_child_weight':1} 
Base_Model = xgb.train(our_params, xgdmat, num_boost_round = 10)

#Below code reads the dump file created by xgboost and writes a scoring code in SAS:

import re
def string_parser(s):
    if len(re.findall(r":leaf=", s)) == 0:
        out  = re.findall(r"[\w.-]+", s)
        tabs = re.findall(r"[\t]+", s)
        if (out[4] == out[8]):
            missing_value_handling = (" or missing(" + out[1] + ")")
        else:
            missing_value_handling = ""

        if len(tabs) > 0:
            return (re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 
                    '        if state = ' + out[0] + ' then do;\n' +
                    re.findall(r"[\t]+", s)[0].replace('\t', '    ') +
                    '            if ' + out[1] + ' < ' + out[2] + missing_value_handling +
                    ' then state = ' + out[4] + ';' +  ' else state = ' + out[6] + ';\nend;' ) 
        else:
            return ('        if state = ' + out[0] + ' then do;\n' +
                    '            if ' + out[1] + ' < ' + out[2] + missing_value_handling +
                    ' then state = ' + out[4] + ';' +  ' else state = ' + out[6] + ';\nend;' )
    else:
        out = re.findall(r"[\w.-]+", s)
        return (re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 
                '        if state = ' + out[0] + ' then\n    ' +
                re.findall(r"[\t]+", s)[0].replace('\t', '    ') + 
                '        value = value + (' + out[2] + ') ;\n')

def tree_parser(tree, i):
    return ('state = 0;\n'
             + "".join([string_parser(tree.split('\n')[i]) for i in range(len(tree.split('\n'))-1)]))

def model_to_sas(model, out_file):
    trees = model.get_dump()
    result = ["value = 0;\n"]
    with open(out_file, 'w') as the_file:
        for i in range(len(trees)):
            result.append(tree_parser(trees[i], i))
        the_file.write("".join(result))
        the_file.write("\nY_Pred1 = 1/(1+exp(-value));\n")
        the_file.write("Y_Pred0 = 1 - Y_pred1;") 

通话以上上述模块的创建SAS得分代码:

model_to_sas(Base_Model, 'xgb_scr_code.sas')

不幸的是,我无法提供上述模块生成的完整SAS代码。但是,如果我们仅使用一个树形代码来构建模型,请在SAS代码下面找到:

value = 0;
state = 0;
if state = 0 then
    do;
        if sepal_width < 2.95000005 or missing(sepal_width) then state = 1;
        else state = 2;
    end;
if state = 1 then
    do;
        if petal_length < 4.75 or missing(petal_length) then state = 3;
        else state = 4;
    end;

if state = 3 then   value = value + (0.1586207);
if state = 4 then   value = value + (-0.127272725);
if state = 2 then
    do;
        if petal_length < 3 or missing(petal_length) then state = 5;
        else state = 6;
    end;
if state = 5 then   value = value + (-0.180952385);
if state = 6 then
    do;
        if petal_length < 4.75 or missing(petal_length) then state = 7;
        else state = 8;
    end;
if state = 7 then   value = value + (0.142857149);
if state = 8 then   value = value + (-0.161290333);

Y_Pred1 = 1/(1+exp(-value));
Y_Pred0 = 1 - Y_pred1;

以下是第一棵树的转储文件输出:

booster[0]:
    0:[sepal_width<2.95000005] yes=1,no=2,missing=1
        1:[petal_length<4.75] yes=3,no=4,missing=3
            3:leaf=0.1586207
            4:leaf=-0.127272725
        2:[petal_length<3] yes=5,no=6,missing=5
            5:leaf=-0.180952385
            6:[petal_length<4.75] yes=7,no=8,missing=7
                7:leaf=0.142857149
                8:leaf=-0.161290333

所以基本上,我试图做的是,保存在变量“状态”的节点数量,并相应地访问叶子节点(这是我从文章Shiutang丽学会了上面的链接提到)。< / p>

这是我面临的问题:

对于大约40棵树,预测分数完全匹配。例如,请参见以下内容:

情况1:

使用python预测的10棵树的值:

Y_pred1 = Base_Model.predict(xgdmat)

print("Development- Y_Actual: ",np.mean(Y)," Y predicted: ",np.mean(Y_pred1))

输出:

Average- Y_Actual:  0.3333333333333333  Average Y predicted:  0.4021197

使用SAS对10棵树的预测值:

Average Y predicted:  0.4021197

情况2:

使用Python 100种树木

预测值:

Y_pred1 = Base_Model.predict(xgdmat)

print("Development- Y_Actual: ",np.mean(Y)," Y predicted: ",np.mean(Y_pred1))

输出:

Average- Y_Actual:  0.3333333333333333  Average Y predicted:  0.33232176

使用SAS对100棵树的预测值:

Average Y predicted:  0.3323159

正如你可以看到的分数不完全匹配(匹配到4个小数点)100倍的树木。另外,我在分数差异很大(即分数偏差超过10%)的大型文件上尝试过这种方法。

任何人都可以让我指出代码中的任何错误,以便分数可以完全匹配。以下是我的一些查询:

1)我的分数计算正确吗?

2)我发现了与gamma(正则化术语)有关的东西。它是否影响使用叶值xgboost计算分数的方式。

3)转储文件给出的叶子值是否会四舍五入,从而造成此问题

另外,我将理解任何其他方法来从解析转储文件分开执行此任务。

P.S:我只有SAS EG和没有获得SAS EM或SAS IML。

3 个答案:

答案 0 :(得分:1)

我在获得匹配分数方面也有类似的经验。
我了解到,除非您修复ntree_limit选项以匹配模型拟合期间使用的n_estimators,否则评分可能会提前终止。

df['score']= xgclfpkl.predict(df[xg_features], ntree_limit=500)

开始使用ntree_limit之后,我开始获得匹配分数。

答案 1 :(得分:0)

我有点想将其合并到自己的代码中。

我发现缺少处理方面存在一个小问题。

如果您有类似的逻辑,这似乎很好用

    <html>
<body>
<h2>Hello World!</h2>
<form action="getuserdata" method="get">
<input type= "text" name="userid"><br>
<input type= "text" name="firstname"><br>
<input type="submit">
</form>
</body>
</html>

但是说丢失的组应该进入状态6而不是状态5。然后您将获得如下代码:

    <!DOCTYPE web-app PUBLIC
 "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
 "http://java.sun.com/dtd/web-app_2_3.dtd" >

<web-app>
  <display-name>Archetype Created Web Application</display-name>

  <servlet>
    <servlet-name>GmUserData</servlet-name>
    <servlet-class>com.Gmaps.web.controller.GmUserData</servlet-class>
  </servlet>

  <servlet-mapping>
    <servlet-name>GmUserData</servlet-name>
    <url-pattern>/getuserdata</url-pattern>
  </servlet-mapping>
</web-app>

在这种情况下,if petal_length < 3 or missing(petal_length) then state = 5; else state = 6; 处于什么状态? 好吧,这里仍然进入状态5(而不是预期的状态6),因为SAS丢失被归类为小于任何数字。

要解决此问题,您可以将所有缺少的值分配给999999999999999(由于XGBoost格式始终使用小于(<),所以请选择一个较大的数字),然后替换

if petal_length < 3 then state = 5;
        else state = 6;

使用

petal_length = missing (.)

在您的missing_value_handling = (" or missing(" + out[1] + ")") 中。

答案 2 :(得分:0)

点对-

首先,与叶子返回值匹配的正则表达式捕获转储中的“ e小数”科学计数法(默认)。明确的示例(第二个是正确的修改!)-

s = '3:leaf=9.95066429e-09'
out = re.findall(r"[\d.-]+", s)
out2 = re.findall(r"-?[\d.]+(?:e-?\d+)?", s)
out2,out

(易于修复,但不能准确地发现我的模型中有一片叶子受到影响!)

第二,问题是关于二进制的,但是在多类目标中,转储中每个类都有单独的树,因此,您总共有T*C棵树,其中T是提升轮数C是类数。对于类c(在{0,1,...,C-1}中),您需要评估i*C +c的树i = 0,...,T-1(并对其末端叶子求和)。然后将其最大化以匹配xgb的预测。