如何从 Spark MLlib 决策树中提取规则

时间:2015-08-03 08:04:29

标签: apache-spark apache-spark-mllib

我正在使用Spark MLlib 1.4.1来创建decisionTree模型。现在我想从决策树中提取规则。


3 个答案:

文档在这里(The documentation is here),其中包含一个使用示例数据的例子,您可以在命令行中查看输出格式。下面是我整理好的脚本,您可以直接粘贴并运行。

from numpy import array
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

# Tiny training set: label 0.0 for feature value 0.0, label 1.0 for the rest.
data = [
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(1.0, [2.0]),
    LabeledPoint(1.0, [3.0]),
]

# NOTE(review): assumes `sc` is an existing SparkContext (e.g. the one the
# pyspark shell provides) — confirm before running as a standalone script.
# Positional arguments: training RDD, number of classes, and an empty
# categoricalFeaturesInfo mapping.
model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})

# Print the one-line summary, then the full If/Else rule dump.
print(model)
print(model.toDebugString())



DecisionTreeModel classifier of depth 1 with 3 nodes
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0 



import os

# NOTE(review): `parsedTrainData` must already be defined as an RDD of
# LabeledPoint — it is not created in this snippet.
dtModel = DecisionTree.trainClassifier(
    parsedTrainData,
    numClasses=7,
    categoricalFeaturesInfo={},
    impurity='gini',
    maxDepth=20,
    maxBins=24,
)

# open() does not expand "~", so resolve the home directory explicitly.
modelFile = os.path.expanduser("~/decisionTreeModel.txt")
# Write the If/Else rule dump; the context manager closes the file even if
# the write fails.
with open(modelFile, "w") as f:
    f.write(dtModel.toDebugString())


DecisionTreeModel classifier of depth 20 with 20031 nodes
  If (feature 0 <= -35.0)
   If (feature 24 <= 176.0)
    If (feature 0 <= -200.0)
     If (feature 29 <= 109.0)
      If (feature 6 <= -156.0)
       If (feature 9 <= 0.0)
        If (feature 20 <= -116.0)
         If (feature 16 <= 203.0)
          If (feature 11 <= 163.0)
           If (feature 5 <= 384.0)
            If (feature 15 <= 325.0)
             If (feature 13 <= -248.0)
              If (feature 20 <= -146.0)
               Predict: 0.0
              Else (feature 20 > -146.0)
               If (feature 19 <= -58.0)
                Predict: 6.0
               Else (feature 19 > -58.0)
                Predict: 0.0
             Else (feature 13 > -248.0)
              If (feature 9 <= -26.0)
               Predict: 0.0
              Else (feature 9 > -26.0)
               If (feature 10 <= 218.0)

import networkx as nx


# Load the per-node table stored under the model's data directory.
# NOTE(review): `spark` and `location` must already be defined; `location` is
# presumably the directory the model was saved to — confirm.
modeldf = spark.read.parquet(location + "/data/*")

noderows = modeldf.select("id", "prediction", "leftChild", "rightChild", "split").collect()

# Placeholder feature names, looked up by feature index when labelling edges.
features = ["feature" + str(i) for i in range(0, 700)]

# One graph node per tree node; edges are added in the second pass below.
G = nx.DiGraph()
for rw in noderows:
    if rw['leftChild'] < 0 and rw['rightChild'] < 0:
        # Both child ids are negative: record the node as a prediction node.
        G.add_node(rw['id'], cat="Prediction", predval=rw['prediction'])
    else:
        # Otherwise record the split definition on the node.
        G.add_node(
            rw['id'],
            cat="splitter",
            featureIndex=rw['split']['featureIndex'],
            thresh=rw['split']['leftCategoriesOrThreshold'],
            leftChild=rw['leftChild'],
            rightChild=rw['rightChild'],
            numCat=rw['split']['numCategories'],
        )

for rw in modeldf.where("leftChild > 0 and rightChild > 0").collect():
    # Attribute dict stored on this node by the loop above.
    tempnode = G.nodes[rw['id']]
    feature_name = features[tempnode['featureIndex']]

    # The left edge carries the "less than" condition, the right edge the
    # "greater than" condition, both against the same threshold.
    G.add_edge(rw['id'], rw['leftChild'], reason="{0} less than {1}".format(feature_name, tempnode['thresh']))
    G.add_edge(rw['id'], rw['rightChild'], reason="{0} greater than {1}".format(feature_name, tempnode['thresh']))

上面的代码将所有规则转换为一个有向图(网络)结构。要以 if/else 的形式打印所有规则,我们可以找出从根节点到每个叶节点的路径,并依次列出路径上各条边的 reason 属性,从而得到最终的规则。

# Select the nodes that have exactly one incoming edge and no outgoing edges.
nodes = [node for node in G.nodes() if G.in_degree(node) == 1 and G.out_degree(node) == 0]

for leaf in nodes:
    # Path from node 0 (used here as the starting node) down to this leaf.
    path = nx.shortest_path(G, 0, leaf)

    print("Rule No:", leaf)

    # Join the 'reason' labels of each consecutive pair of nodes on the path.
    reasons = [G.get_edge_data(src, dst)['reason'] for src, dst in zip(path, path[1:])]
    print(" & ".join(reasons))





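feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 less than [1.0]




feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 less than [0.0]




feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 greater than [0.0]




feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 less than [0.0]




feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 greater than [0.0]




feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 less than [1.0]




feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 greater than [1.0]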


from pyspark.sql.functions import to_date,datediff,lit,udf,sum,avg,col,count,lag
from pyspark.sql.types import StringType,LongType,StructType,StructField,DateType,IntegerType,DoubleType
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline
import pandas as pd
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, lit, avg, max, min
from pyspark.sql.types import StringType, ArrayType, DoubleType
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
import operator

import ast

# Maps the operator tokens that appear in tree-node conditions to the
# two-argument callables from the `operator` module that implement them.
operators = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
    'and': operator.and_,
    'or': operator.or_,
}

# Toy dataset: five numeric feature columns plus a multi-class 'label' column.
data = pd.DataFrame({
    'ball': [0, 1, 1, 3, 1, 0, 1, 3],
    'keep': [4, 5, 6, 7, 7, 4, 6, 7],
    'hall': [8, 9, 10, 11, 2, 6, 10, 11],
    'fall': [12, 13, 14, 15, 15, 12, 14, 15],
    'mall': [16, 17, 18, 10, 10, 16, 18, 10],
    'label': [21, 31, 41, 51, 51, 51, 21, 31],
})
# NOTE(review): assumes `spark` is an existing SparkSession — it is not
# created in this snippet.
df = spark.createDataFrame(data)

# Order matters: the position in this list is the feature index that the
# tree's debug string refers to ("feature 3" is 'hall' in the sample output).
f_list = ['ball', 'keep', 'mall', 'hall', 'fall']
# NOTE(review): the original call was cut off after `outputCol='features',`;
# any further keyword arguments it had are unknown.
assemble_numerical_features = VectorAssembler(inputCols=f_list, outputCol='features')

dt = DecisionTreeClassifier(featuresCol='features', labelCol='label')

pipeline = Pipeline(stages=[assemble_numerical_features, dt])
model = pipeline.fit(df)
df = model.transform(df)
# The last pipeline stage is the fitted decision-tree model.
dt_m = model.stages[-1]

# Step 1: convert model.debugString output to dictionary of nodes and children
def _condition_text(line):
    """Return an 'If (...)' / 'Else (...)' line without its keyword and parentheses."""
    return ' '.join(line.split()[1:]).replace('(', '').replace(')', '')


def parse_debug_string_lines(lines):
    """Parse stripped debug-string lines into a nested list of node dicts.

    Each 'If (...)' or 'Else (...)' line becomes
    ``{'name': <condition>, 'children': [...]}``; every other line becomes
    ``{'name': <line>}``.

    Args:
        lines: List of already-stripped lines. It is consumed in place from
            the front; the recursive calls rely on sharing this one list.

    Returns:
        The list of node dicts parsed at the current nesting level.
    """
    block = []
    while lines:
        head = lines[0]
        if head.startswith('If'):
            block.append({'name': _condition_text(lines.pop(0)),
                          'children': parse_debug_string_lines(lines)})

            # Guard against running out of lines before indexing again.
            if lines and lines[0].startswith('Else'):
                block.append({'name': _condition_text(lines.pop(0)),
                              'children': parse_debug_string_lines(lines)})
        elif head.startswith('Else'):
            # This 'Else' pairs with an 'If' handled by an enclosing call, so
            # hand control back. Without this branch the loop never ends.
            break
        else:
            block.append({'name': lines.pop(0)})
    return block

def debug_str_to_json(debug_string):
    """Convert a decision-tree debug string into a nested dict.

    Args:
        debug_string: Multi-line text whose first non-blank line is a header
            and whose remaining lines are If/Else/Predict entries.

    Returns:
        ``{'name': 'Root', 'children': [...]}`` where the children are built
        by ``parse_debug_string_lines`` from every line after the header.
    """
    data = []
    for raw_line in debug_string.splitlines():
        stripped = raw_line.strip()
        if stripped:
            # Collect the stripped line so the parser can match on prefixes.
            data.append(stripped)
        elif not raw_line:
            # A completely empty line ends the tree description.
            break
    # data[0] is the header line, not a tree node, so it is skipped.
    tree = {'name': 'Root', 'children': parse_debug_string_lines(data[1:])}
    return tree

# Step 2: using the metadata stored on the 'features' column, build a dict
# that maps each feature's index in the feature vector to its name.
f_type_to_flist_dict = df.schema['features'].metadata["ml_attr"]["attrs"]
f_index_to_name_dict = {}
# Only the attribute lists are needed, not the type keys they are grouped
# under. Distinct loop names keep the module-level `f_list` intact.
for attrs_of_type in f_type_to_flist_dict.values():
    for attr in attrs_of_type:
        f_index_to_name_dict[attr['idx']] = attr['name']

def generate_explanations(dt_as_json, df: "DataFrame", f_index_to_name_dict, operators):
    """Add an 'explanation' column describing the tree path each row follows.

    Args:
        dt_as_json: Nested dict of ``{'name': ..., 'children': [...]}`` nodes
            describing the decision tree.
        df: DataFrame that has a 'features' vector column.
        f_index_to_name_dict: Maps a feature index to its feature name.
        operators: Maps an operator token (e.g. '<=') to a two-argument
            callable.

    Returns:
        ``df`` with two added columns: 'features_list' (the feature vector as
        a list of doubles) and 'explanation' (comma-separated conditions).
    """
    # The tree travels to the UDF as a literal string column, so serialize it.
    dt_as_json_str = str(dt_as_json)
    cond_parsing_exception_occured = False

    # Convert the vector column to a plain list so it can be indexed by
    # feature position inside the Python UDF.
    df = df.withColumn(
        'features' + '_list',
        udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))(df['features']),
    )

    # Step 3: parse a node condition and check whether the current instance
    # satisfies it.
    def parse_validate_cond(cond: str, f_vector: list):
        # Expected shape: "feature <index> <op> <value>".
        cond_parts = cond.split()
        condition_f_index = int(cond_parts[1])
        condition_op = cond_parts[2]
        condition_value = float(cond_parts[3])

        f_value = f_vector[condition_f_index]
        f_name = (
            f_index_to_name_dict[condition_f_index]
            .replace('numerical_features_', '')
            .replace('encoded_numeric_', '')
            .lower()
        )

        if operators[condition_op](f_value, condition_value):
            return True, f_name + ' ' + condition_op + ' ' + str(round(condition_value, 2))

        return False, ''

    # Step 4: extract the rule for one instance by following the nodes whose
    # condition it satisfies until a prediction node is reached.
    def extract_rule(dt_as_json_str: str, f_vector: list, rule=""):
        # Rebinding a variable of the enclosing function requires `nonlocal`.
        nonlocal cond_parsing_exception_occured

        dt_as_json = ast.literal_eval(dt_as_json_str)
        child_l = dt_as_json['children']

        for child in child_l:
            name = child['name'].strip()

            if name.startswith('Predict:'):
                # Drop the trailing ", ". rpartition yields '' instead of
                # raising when the tree is a single prediction node.
                return rule.rpartition(',')[0]

            if name.startswith('feature'):
                try:
                    res, cond = parse_validate_cond(child['name'], f_vector)
                except Exception:
                    # Best effort: an unparsable condition counts as not
                    # satisfied and is only flagged.
                    res = False
                    cond_parsing_exception_occured = True
                if res:
                    rule += cond + ', '
                    rule = extract_rule(str(child), f_vector, rule=rule)
        return rule

    df = df.withColumn(
        'explanation',
        udf(lambda dt, fv: extract_rule(dt, fv), StringType())(
            lit(dt_as_json_str), df['features' + '_list']
        ),
    )
    # Report any exception hit while parsing a tree-node condition.
    # NOTE(review): the UDF runs lazily inside Spark workers, so a flag set
    # there is probably never visible here — confirm.
    if cond_parsing_exception_occured:
        print('some node in decision tree has unexpected format')

    return df

# Turn the fitted tree's debug string into a nested dict, attach the per-row
# explanation column, then collect the columns of interest.
dt_as_json = debug_str_to_json(dt_m.toDebugString)
df = generate_explanations(dt_as_json, df, f_index_to_name_dict, operators)
result_columns = ['ball', 'keep', 'mall', 'hall', 'fall', 'explanation', 'prediction']
rows = df.select(result_columns).collect()

output :
[Row(ball=0, keep=4, mall=16, hall=8, fall=12, explanation='hall > 7.0, mall > 13.0, ball <= 0.5', prediction=21.0),
 Row(ball=1, keep=5, mall=17, hall=9, fall=13, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep <= 5.5', prediction=31.0),
 Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
 Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0),
 Row(ball=1, keep=7, mall=10, hall=2, fall=15, explanation='hall <= 7.0', prediction=51.0),
 Row(ball=0, keep=4, mall=16, hall=6, fall=12, explanation='hall <= 7.0', prediction=51.0),
 Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
 Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0)]

output of dt_m.toDebugString:
'DecisionTreeClassificationModel (uid=DecisionTreeClassifier_2a17ae7633b9) of depth 4 with 9 nodes\n  If (feature 3 <= 7.0)\n   Predict: 51.0\n  Else (feature 3 > 7.0)\n   If (feature 2 <= 13.0)\n    Predict: 31.0\n   Else (feature 2 > 13.0)\n    If (feature 0 <= 0.5)\n     Predict: 21.0\n    Else (feature 0 > 0.5)\n     If (feature 1 <= 5.5)\n      Predict: 31.0\n     Else (feature 1 > 5.5)\n      Predict: 21.0\n'

output of debug_str_to_json(dt_m.toDebugString):
{'name': 'Root',
'children': [{'name': 'feature 3 <= 7.0',
   'children': [{'name': 'Predict: 51.0'}]},
  {'name': 'feature 3 > 7.0',
   'children': [{'name': 'feature 2 <= 13.0',
     'children': [{'name': 'Predict: 31.0'}]},
    {'name': 'feature 2 > 13.0',
     'children': [{'name': 'feature 0 <= 0.5',
       'children': [{'name': 'Predict: 21.0'}]},
      {'name': 'feature 0 > 0.5',
       'children': [{'name': 'feature 1 <= 5.5',
         'children': [{'name': 'Predict: 31.0'}]},
        {'name': 'feature 1 > 5.5',
         'children': [{'name': 'Predict: 21.0'}]}]}]}]}]}