我使用像这样的管道训练了DecisionTreeClassifier
模型:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier
cl = DecisionTreeClassifier(labelCol='target_idx', featuresCol='features')
pipe = Pipeline(stages=[target_index, assembler, cl])
model = pipe.fit(df_train)
# Prediction and model evaluation
predictions = model.transform(df_test)
其中阶段是StringIndexer
和VectorAssembler
的实例。我现在可以评估模型的准确性,例如
mc_evaluator = MulticlassClassificationEvaluator(
labelCol="target_idx", predictionCol="prediction", metricName="precision" )
accuracy = mc_evaluator.evaluate(predictions)
print("Test Error = {}".format(1.0 - accuracy))
大。现在我需要检查树模型结构。文档指向了一个名为toDebugString
的属性,但ML DecisionTreeClassifier
没有这个属性 - 它似乎只是MLLib DecisionTree
分类器的属性。如何从ML版本的管道内的模型中获取树结构并绘制它?
答案 0 :(得分:2)
这在pyspark中对我有用:
model.stages[2]._call_java('toDebugString')