Multiclass classification with One-vs-Rest and a gradient-boosted tree classifier

Time: 2019-06-15 10:43:01

Tags: pyspark apache-spark-mllib multiclass-classification

I want to build a PySpark classifier with gradient-boosted trees for a multiclass classification task. I tried:

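from pyspark.ml.classification import GBTClassifier, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Wrap the gradient-boosted tree classifier in a One-vs-Rest reduction
gb = GBTClassifier(maxIter=10)
ovr = OneVsRest(classifier=gb)
ovrModel = ovr.fit(trainingData)  # the error is raised here
gb_predictions = ovrModel.transform(valData)
# Score the predictions on the validation set
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
gb_accuracy = evaluator.evaluate(gb_predictions)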

When I run the code above, I get this error:

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
AssertionError: Classifier <class 'pyspark.ml.classification.GBTClassifier'> doesn't extend from HasRawPredictionCol.

The traceback points to the line ovrModel = ovr.fit(trainingData), but I don't understand what is wrong with the training data.
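For context, the assertion message names the classifier class rather than the data: it says GBTClassifier does not extend HasRawPredictionCol. A likely reason One-vs-Rest needs a raw prediction is that it trains one binary model per class and then picks the class whose model gives the highest score. The following is a minimal pure-Python sketch of that reduction, not Spark's implementation. The centroid-based scorer and all function names here are hypothetical stand-ins chosen only for illustration.

```python
from typing import Callable, List, Sequence

# A "scorer" maps a feature vector to a raw score (higher = more likely positive).
Scorer = Callable[[Sequence[float]], float]


def centroid(points: Sequence[Sequence[float]]) -> List[float]:
    """Column-wise mean of a non-empty list of points."""
    n = len(points)
    return [sum(col) / n for col in zip(*points)]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def fit_binary_scorer(features, binary_labels) -> Scorer:
    """Hypothetical binary 'classifier': scores by distance to class centroids."""
    pos = [x for x, y in zip(features, binary_labels) if y == 1]
    neg = [x for x, y in zip(features, binary_labels) if y == 0]
    pos_c, neg_c = centroid(pos), centroid(neg)

    def raw_score(x: Sequence[float]) -> float:
        # Positive when x is closer to the positive centroid than the negative one.
        return sq_dist(x, neg_c) - sq_dist(x, pos_c)

    return raw_score


def fit_one_vs_rest(features, labels) -> List[Scorer]:
    """Train one binary scorer per class (class k versus everything else)."""
    # Same class-count rule as the line shown in the traceback: max label + 1.
    num_classes = int(max(labels)) + 1
    return [
        fit_binary_scorer(features, [1 if y == k else 0 for y in labels])
        for k in range(num_classes)
    ]


def predict(scorers: List[Scorer], x: Sequence[float]) -> int:
    """Pick the class whose binary scorer returns the highest raw score."""
    scores = [score(x) for score in scorers]
    return scores.index(max(scores))


# Toy data: three well-separated classes in 2D (made up for this sketch).
features = [[0.0, 0.0], [1.0, 0.0], [5.0, 0.0], [6.0, 0.0], [0.0, 5.0], [0.0, 6.0]]
labels = [0, 0, 1, 1, 2, 2]

scorers = fit_one_vs_rest(features, labels)
print([predict(scorers, x) for x in ([0.5, 0.5], [5.5, 0.5], [0.5, 5.5])])  # [0, 1, 2]
```

Without a comparable score from each binary model, the final "pick the highest" step has nothing to compare, which would explain why the check fails before any training data is examined. If that reading is right, possible directions are a base classifier that exposes a raw prediction column, or a classifier that handles multiple classes directly such as RandomForestClassifier; whether either fits this task is an assumption to verify against the Spark version in use.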

0 answers:

No answers yet