How to grid search random forest hyperparameters with the Apache Spark ML library

Time: 2019-01-15 21:37:49

Tags: apache-spark apache-spark-mllib

I want to run a grid search over a random forest model in Apache Spark, but I cannot find an example of how to do it. Is there an example of hyperparameter tuning with grid search on some sample data?

1 answer:

Answer 0: (score: 1)

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder


# Random forest that expects pre-indexed label and feature columns
rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10)
pipeline = Pipeline(stages=[rf])

# Candidate values for numTrees; each one is tried during cross-validation
paramGrid = ParamGridBuilder().addGrid(rf.numTrees, [10, 30]).build()

# The evaluator should score the same label column the classifier is trained on
crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(labelCol="indexedLabel"),
                          numFolds=2)

# training_df must already contain the indexedLabel and indexedFeatures columns
cvModel = crossval.fit(training_df)

The hyperparameters and their candidate values (the grid) are defined via the addGrid method.
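To search over more than one hyperparameter, additional addGrid calls can be chained before build(); after fitting, the cross-validated metrics and the winning model can be inspected. Below is a minimal sketch that reuses rf, pipeline and training_df from the answer above; the maxDepth values and the test_df DataFrame are illustrative assumptions, not part of the original answer.

from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Every combination of the listed values is tried: 2 x 3 = 6 candidate models,
# each trained numFolds times during cross-validation.
paramGrid = (ParamGridBuilder()
             .addGrid(rf.numTrees, [10, 30])
             .addGrid(rf.maxDepth, [5, 10, 15])
             .build())

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(labelCol="indexedLabel"),
                          numFolds=2)
cvModel = crossval.fit(training_df)

# One averaged metric per grid point, in the same order as paramGrid
print(cvModel.avgMetrics)

# The best pipeline found on the grid; its first stage is the fitted random forest
bestRf = cvModel.bestModel.stages[0]
print(bestRf)  # the summary string reports the number of trees used

# CrossValidatorModel.transform applies the best model; test_df (hypothetical)
# must already contain the indexedLabel and indexedFeatures columns
predictions = cvModel.transform(test_df)
print(BinaryClassificationEvaluator(labelCol="indexedLabel").evaluate(predictions))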