Getting toDebugString from DecisionTreeClassifier in PySpark ML

Date: 2016-05-03 13:26:20

Tags: python apache-spark pyspark

I trained a DecisionTreeClassifier model using a pipeline like this:

from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier

# target_index (a StringIndexer) and assembler (a VectorAssembler) are defined earlier
cl = DecisionTreeClassifier(labelCol='target_idx', featuresCol='features')
pipe = Pipeline(stages=[target_index, assembler, cl])
model = pipe.fit(df_train)

# Prediction and model evaluation
predictions = model.transform(df_test)

where the first two stages are instances of StringIndexer and VectorAssembler. I can now evaluate the model's accuracy, for example:

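from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# metricName="precision" is the Spark 1.x name; Spark 2.0+ uses "accuracy" instead
mc_evaluator = MulticlassClassificationEvaluator(
    labelCol="target_idx", predictionCol="prediction", metricName="precision")

accuracy = mc_evaluator.evaluate(predictions)
print("Test Error = {}".format(1.0 - accuracy))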

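Great. Now I need to inspect the structure of the tree model. The documentation points to a property called toDebugString, but the ML DecisionTreeClassifier doesn't have it; it seems to exist only on the MLlib DecisionTree classifier. How can I get the tree structure from the ML version of the model inside the pipeline, and plot it?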

1 answer:

Answer 0 (score: 2)

This worked for me in PySpark:

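# stages[2] is the fitted tree model (third pipeline stage); _call_java invokes its JVM-side toDebugString
print(model.stages[2]._call_java('toDebugString'))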
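The question also asks about plotting the tree. Below is a minimal sketch, assuming the plain-text layout shown in the hand-written sample (it is illustrative, not captured Spark output). `tree_debug_string`, `parse_debug_string` and `to_dot` are hypothetical helpers, not PySpark APIs. The first tries the public `toDebugString` property that newer PySpark releases (2.0 onward, as far as I know) expose on the fitted tree model, and falls back to the `_call_java` call from the answer. The other two turn the text into nested dicts and then into a Graphviz DOT graph. The debug-string format is not a documented contract, so the parser may need adjusting between Spark versions.

```python
def tree_debug_string(pipeline_model, stage_index=-1):
    """Return the toDebugString text of a fitted tree stage (hypothetical helper).

    stage_index=-1 assumes the tree is the last pipeline stage, as in the question.
    """
    tree_model = pipeline_model.stages[stage_index]
    try:
        # Newer PySpark versions expose toDebugString as a read-only property.
        return tree_model.toDebugString
    except AttributeError:
        # Older versions: call the JVM-side method through the private
        # _call_java helper, as in the answer.
        return tree_model._call_java("toDebugString")


def parse_debug_string(text):
    """Parse a toDebugString dump into nested dicts (hypothetical helper).

    Assumes the pre-order layout Spark prints: an "If (...)" line, its left
    subtree, an "Else (...)" line, its right subtree; leaves are "Predict: ...".
    The header line naming the model is ignored.
    """
    nodes = [ln.strip() for ln in text.splitlines()
             if ln.strip().startswith(("If ", "Else ", "Predict:"))]
    if not nodes:
        raise ValueError("no tree nodes found")
    pos = 0

    def build():
        nonlocal pos
        line = nodes[pos]
        pos += 1
        if line.startswith("Predict:"):
            return {"predict": line.split(":", 1)[1].strip()}
        if not line.startswith("If ("):
            raise ValueError("unexpected line: " + line)
        split = line[len("If ("):-1]  # drop "If (" and the closing ")"
        left = build()
        if pos >= len(nodes) or not nodes[pos].startswith("Else "):
            raise ValueError("missing Else branch for: " + split)
        pos += 1  # the Else line repeats the negated condition; skip it
        right = build()
        return {"split": split, "left": left, "right": right}

    return build()


def to_dot(tree):
    """Render the parsed tree as Graphviz DOT source (hypothetical helper)."""
    lines = ["digraph tree {"]
    counter = [0]

    def emit(node):
        node_id = counter[0]
        counter[0] += 1
        if "predict" in node:
            lines.append(f'  n{node_id} [label="Predict: {node["predict"]}", shape=box];')
        else:
            lines.append(f'  n{node_id} [label="{node["split"]}"];')
            left_id = emit(node["left"])
            lines.append(f'  n{node_id} -> n{left_id} [label="yes"];')
            right_id = emit(node["right"])
            lines.append(f'  n{node_id} -> n{right_id} [label="no"];')
        return node_id

    emit(tree)
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    # Hand-written sample in the layout assumed above; with a real pipeline
    # use sample = tree_debug_string(model) instead.
    sample = """DecisionTreeClassificationModel (uid=dtc_example) of depth 2 with 5 nodes
  If (feature 0 <= 0.5)
   Predict: 0.0
  Else (feature 0 > 0.5)
   If (feature 1 <= 1.5)
    Predict: 1.0
   Else (feature 1 > 1.5)
    Predict: 0.0
"""
    print(to_dot(parse_debug_string(sample)))
```

Save the printed DOT text to a file and render it with Graphviz, e.g. `dot -Tpng tree.dot -o tree.png`. Parsing the text keeps everything on the Python side; the trade-off is that it depends on Spark's print format rather than on a stable API.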