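I am using PySpark for machine learning and want to train a decision tree classifier, a random forest and gradient-boosted trees. I want to try different values for the maximum depth and select the best one via grid search with cross-validation. However, Spark tells me that DecisionTree currently only supports maxDepth <= 30. What is the reason for limiting it to 30? Is there a way to increase it? I am working with text data and my feature vectors are TF-IDF, so I want to try higher values for the maximum depth. Here is the example code from the Spark website with a few modifications: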
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# Load and parse the data file, converting it to a DataFrame.
data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
# Index labels, adding metadata to the label column.
# Fit on the whole dataset to include all labels in the index.
labelIndexer = StringIndexer(inputCol="label",
                             outputCol="indexedLabel").fit(data)
# Automatically identify categorical features, and index them.
# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer = VectorIndexer(inputCol="features", outputCol="indexedFeatures",
                               maxCategories=4).fit(data)
# Split the data into training and test sets (30% held out for testing).
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a RandomForest model.
rf = RandomForestClassifier(labelCol="indexedLabel",
                            featuresCol="indexedFeatures", numTrees=500)
# Convert indexed labels back to original labels.
labelConverter = IndexToString(inputCol="prediction",
                               outputCol="predictedLabel",
                               labels=labelIndexer.labels)
# Chain indexers and forest in a Pipeline.
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter])
# Every maxDepth value here is above 30, which triggers the error below.
paramGrid_rf = ParamGridBuilder() \
    .addGrid(rf.maxDepth, [50, 100, 150, 250, 300]) \
    .build()
# Evaluate against the indexed label the model was trained on.
crossval_rf = CrossValidator(estimator=pipeline,
                             estimatorParamMaps=paramGrid_rf,
                             evaluator=BinaryClassificationEvaluator(labelCol="indexedLabel"),
                             numFolds=5)
cvModel_rf = crossval_rf.fit(trainingData)
The code above produces the following error message:
Py4JJavaError: An error occurred while calling o12383.fit. : java.lang.IllegalArgumentException: requirement failed: DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = 50.
Answer
...the current implementation imposes a limit of maxDepth <= 30.
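As far as I understand the implementation (worth confirming in the Spark source), the 30 comes from how tree nodes are numbered: heap-style, 1-based indices where the children of node i are 2i and 2i + 1, stored in a signed 32-bit Int. The pure-Python sketch below reproduces that arithmetic under this assumption. At depth 30 the deepest node gets index 2^31 - 1 (Java's Int.MaxValue), while depth 31 would overflow.

```python
# Assumption: nodes are numbered heap-style starting at 1 (root = 1),
# the children of node i are 2*i and 2*i + 1, and indices are signed 32-bit ints.
INT32_MAX = 2**31 - 1


def right_child(i: int) -> int:
    """Index of the right child of node i (the left child would be 2 * i)."""
    return 2 * i + 1


def max_index_at_depth(depth: int) -> int:
    """Largest node index at the given depth (the root is depth 0)."""
    # Always following right children gives 1, 3, 7, 15, ... = 2**(depth + 1) - 1.
    index = 1
    for _ in range(depth):
        index = right_child(index)
    return index


def deepest_supported_depth() -> int:
    """Deepest level whose node indices all fit in a signed 32-bit int."""
    depth = 0
    while max_index_at_depth(depth + 1) <= INT32_MAX:
        depth += 1
    return depth


if __name__ == "__main__":
    for d in (29, 30, 31):
        idx = max_index_at_depth(d)
        print(d, idx, "fits" if idx <= INT32_MAX else "overflows")
    print("deepest depth that fits:", deepest_supported_depth())
```

Under this numbering scheme, raising the limit would need wider node indices (for example 64-bit), which is an implementation change rather than a configuration option.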
You could ask for the limit to be raised on the GitHub forum!
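Until then, a grid search over maxDepth has to stay within 0 to 30 (Spark also rejects negative depths). Below is a minimal sketch of a hypothetical helper, `supported_depths` (not part of PySpark), that drops out-of-range candidates before they are passed to `ParamGridBuilder.addGrid`:

```python
# Hypothetical helper, not a PySpark API. The limit matches the one
# reported in the error message (maxDepth <= 30).
SPARK_MAX_TREE_DEPTH = 30


def supported_depths(candidates, limit=SPARK_MAX_TREE_DEPTH):
    """Keep only depths in [0, limit], preserving order and dropping duplicates."""
    kept = []
    for d in candidates:
        if 0 <= d <= limit and d not in kept:
            kept.append(d)
    return kept


if __name__ == "__main__":
    # In the question's code this list would go into .addGrid(rf.maxDepth, ...).
    print(supported_depths([5, 10, 20, 30, 50, 100]))
```

For example, `supported_depths([5, 10, 20, 30, 50, 100])` returns `[5, 10, 20, 30]`, so the cross-validation runs instead of failing on the first value above 30.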