GBT分类器不适用于pyspark中的Multiclass Evaluator

时间:2018-11-06 07:13:08

标签: python apache-spark machine-learning pyspark apache-spark-ml

我正在使用 pyspark.ml.classification 中的 GBT 分类器。我正在用火车数据拟合GBT模型。我在这里使用虹膜数据集。我的标签有3个类别,即; 0,1和2 。我有三个班级,并且正在使用 MulticlassClassificationEvaluator 。根据我引用的文档,https://spark.apache.org/docs/2.1.3/ml-classification-regression.html#gradient-boosted-tree-classifier他们提到了 MulticlassClassificationEvaluator 。当我尝试使用相同的内容时,出现错误。有人可以帮我吗? 我的train_data看起来像这样

+-----------------+-----+
|         features|label|
+-----------------+-----+
|[4.3,3.0,1.1,0.1]|  0.0|
|[4.4,3.0,1.3,0.2]|  0.0|
|[4.4,3.2,1.3,0.2]|  0.0|
|[4.5,2.3,1.3,0.3]|  0.0|
|[4.6,3.1,1.5,0.2]|  0.0|
|[4.6,3.2,1.4,0.2]|  0.0|
|[4.6,3.4,1.4,0.3]|  0.0|
|[4.6,3.6,1.0,0.2]|  0.0|
|[4.7,3.2,1.6,0.2]|  0.0|
|[4.8,3.0,1.4,0.1]|  0.0|
|[4.8,3.0,1.4,0.3]|  0.0|
|[4.8,3.4,1.6,0.2]|  0.0|
|[4.8,3.4,1.9,0.2]|  0.0|
|[4.9,2.5,4.5,1.7]|  2.0|
|[4.9,3.0,1.4,0.2]|  0.0|
|[4.9,3.1,1.5,0.1]|  0.0|
|[4.9,3.1,1.5,0.1]|  0.0|
|[5.0,2.3,3.3,1.0]|  1.0|
|[5.0,3.0,1.6,0.2]|  0.0|
|[5.0,3.3,1.4,0.2]|  0.0|
+-----------------+-----+

这是我的代码:

from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
(trainingData, testData) = data.randomSplit([0.7, 0.3])
gbt = GBTClassifier(maxIter=10)
gbt.fit(trainingData)

错误:

    Py4JJavaError: An error occurred while calling o57.fit.
    : org.apache.spark.SparkException: Job aborted due to stage failure: 
Task 0 in stage 14.0 failed 1 times, most recent failure: Lost task 0.0 in stage 14.0 (TID 14, localhost, executor driver): 
java.lang.IllegalArgumentException: 
requirement failed: GBTClassifier was given dataset with invalid label 2.0.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.

0 个答案:

没有答案