如何在pyspark中打印具有功能名称的随机森林的决策路径?

时间:2018-08-01 13:45:58

标签: python apache-spark pyspark

如何修改代码以使用功能名称而不是数字打印决策路径。

import pandas as pd
import pyspark.sql.functions as F
from pyspark.ml import Pipeline, Transformer
from pyspark.sql import DataFrame
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler

data = pd.DataFrame({
    'ball': [0, 1, 2, 3],
    'keep': [4, 5, 6, 7],
    'hall': [8, 9, 10, 11],
    'fall': [12, 13, 14, 15],
    'mall': [16, 17, 18, 10],
    'label': [21, 31, 41, 51]
})

df = spark.createDataFrame(data)

assembler = VectorAssembler(
    inputCols=['ball', 'keep', 'hall', 'fall'], outputCol='features')
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='label')

pipeline = Pipeline(stages=[assembler, dtc]).fit(df)
transformed_pipeline = pipeline.transform(df)

ml_pipeline = pipeline.stages[1]
print(ml_pipeline.toDebugString)

输出:

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_48b3a34f6fb1f1338624) of depth 3 with 7 nodes   If (feature 0 <= 0.5)    Predict: 21.0   Else (feature 0 >
0.5)    If (feature 0 <= 1.5)
    Predict: 31.0    Else (feature 0 > 1.5)
    If (feature 0 <= 2.5)
     Predict: 41.0
    Else (feature 0 > 2.5)
     Predict: 51.0

2 个答案:

答案 0 :(得分:1)

一种选择是手动替换字符串中的文本。为此,我们可以将通过messaging.setBackgroundMessageHandler(function (payload) { var realPush = true; if(realPush) { const notificationOptions = { body: "It is a REAL push", data:"true" }; //We display the notification return self.registration.showNotification(title, notificationOptions); }else { const notificationOptions = { body: "It is a SILENT push", data:"false" }; //We display a fake notification return self.registration.showNotification('To delete',notificationOptions).then(function () { self.registration.getNotifications().then(notifications => { console.log(notifications); for (var i =0;i<notifications.length;i++) { if(notifications[i].data != "true") { //then we destroy the fake notification immedialtely ! notifications[i].close(); } } }) }); } }); 传递的值存储在列表inputCols中,然后每次将模式input_cols替换为列表{的第feature i个元素{1}}。

i

输出:

input_cols

希望这会有所帮助!

答案 1 :(得分:0)

@Florian:当特征数量很大(超过 9 个)时,上面的代码将不起作用。相反,请使用正则表达式使用以下内容。

tree_to_json = mod.stages[-1].toDebugString
for (index, feat) in index_feature_name_tuple:
  pattern = '\((?P<index>feature ' + str(index) + ')' + ' (?P<rest>.*)\)'
  tree_to_json = re.sub(pattern, f'({feat} \g<rest>)', tree_to_json)

print(tree_to_json)

tree_to_json 是原始规则,应转移到具有特征名称的规则。 index_feature_name_tuple 是元组列表,其中每个元组的第一个元素是特征的索引,第二个元素代表特征的名称。您可以使用以下脚本获取 index_feature_name_tuple

df_fitted.schema['features'].metadata["ml_attr"]["attrs"]

其中 df_fitted 是将管道拟合到数据框后转换的数据框。