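I am using the GBT classifier from pyspark.ml.classification and fitting a GBT model on my training data. I'm using the iris dataset here. My label has 3 classes, namely 0, 1 and 2. Because there are three classes, I am using MulticlassClassificationEvaluator. The documentation I referred to, https://spark.apache.org/docs/2.1.3/ml-classification-regression.html#gradient-boosted-tree-classifier, uses MulticlassClassificationEvaluator in its example. When I try to do the same, I get an error. Can someone help me?

My train_data looks like this: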
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[4.3,3.0,1.1,0.1]|  0.0|
|[4.4,3.0,1.3,0.2]|  0.0|
|[4.4,3.2,1.3,0.2]|  0.0|
|[4.5,2.3,1.3,0.3]|  0.0|
|[4.6,3.1,1.5,0.2]|  0.0|
|[4.6,3.2,1.4,0.2]|  0.0|
|[4.6,3.4,1.4,0.3]|  0.0|
|[4.6,3.6,1.0,0.2]|  0.0|
|[4.7,3.2,1.6,0.2]|  0.0|
|[4.8,3.0,1.4,0.1]|  0.0|
|[4.8,3.0,1.4,0.3]|  0.0|
|[4.8,3.4,1.6,0.2]|  0.0|
|[4.8,3.4,1.9,0.2]|  0.0|
|[4.9,2.5,4.5,1.7]|  2.0|
|[4.9,3.0,1.4,0.2]|  0.0|
|[4.9,3.1,1.5,0.1]|  0.0|
|[4.9,3.1,1.5,0.1]|  0.0|
|[5.0,2.3,3.3,1.0]|  1.0|
|[5.0,3.0,1.6,0.2]|  0.0|
|[5.0,3.3,1.4,0.2]|  0.0|
+-----------------+-----+
Here is my code:
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
(trainingData, testData) = data.randomSplit([0.7, 0.3])
gbt = GBTClassifier(maxIter=10)
gbt.fit(trainingData)
Error:
Py4JJavaError: An error occurred while calling o57.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure:
Task 0 in stage 14.0 failed 1 times, most recent failure: Lost task 0.0 in stage 14.0 (TID 14, localhost, executor driver):
java.lang.IllegalArgumentException:
requirement failed: GBTClassifier was given dataset with invalid label 2.0. Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.