Problem with a decision tree classifier

Date: 2018-03-29 20:14:45

Tags: java apache-spark apache-spark-ml

I am trying to run a Decision Tree classifier. The label column is of type Double, with values ranging from -20 to +20.

    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import java.io.File

    // Where the best model is saved
    val dtModelPath = s"file:///home/parv/spark/examples/src/main/scala/org/apache/spark/examples/ml/dtModel"

    // Grid search over impurity and depth; idf, trainData and testData come from earlier code not shown here
    val dtModel = {
      val dtGridSearch = for {
        dtImpurity <- Array("entropy", "gini")
        dtDepth <- Array(3, 5)
      } yield {
        println(s"Training decision tree: impurity $dtImpurity,depth: $dtDepth")
        val dtModel = new DecisionTreeClassifier()
          .setFeaturesCol(idf.getOutputCol)
          .setLabelCol("value") // Double label column with values from -20 to +20
          .setImpurity(dtImpurity)
          .setMaxDepth(dtDepth)
          .setMaxBins(10)
          .setSeed(42)
          .setCacheNodeIds(true)
          .fit(trainData)
        val dtPrediction = dtModel.transform(testData)
        val dtAUC = new BinaryClassificationEvaluator().setLabelCol("value").evaluate(dtPrediction)
        println(s"DT AUC on test data: $dtAUC")
        ((dtImpurity, dtDepth), dtModel, dtAUC)
      }
      // Print the top results, then save and return the model with the highest AUC
      println(dtGridSearch.sortBy(-_._3).take(5).mkString("\n"))
      val bestModel = dtGridSearch.sortBy(-_._3).head._2
      bestModel.write.overwrite.save(dtModelPath)
      bestModel
    }

I get the following error:

    Training decision tree: impurity entropy,depth: 3
    18/03/30 01:06:18 WARN Executor: 1 block locks were not released by TID = 63510:
    [rdd_62747_0]
    18/03/30 01:06:18 ERROR Executor: Exception in task 7.0 in stage 31353.0 (TID 63518)
    java.lang.IllegalArgumentException: requirement failed: Classifier was given dataset with invalid label -6.0. Labels must be integers in range [0, 1, ..., 44), where numClasses = 44.
        at scala.Predef$.require(Predef.scala:224)

1 answer:

Answer 0 (score: 0)

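It looks like you are giving the classifier invalid labels. The error says: Classifier was given dataset with invalid label -6.0. Labels must be integers in range [0, 1, ..., 44). In other words, DecisionTreeClassifier expects each label to be a class index 0, 1, ..., numClasses - 1, so a negative value such as -6.0 is rejected.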
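The rule quoted in that message can be written as a small predicate. This is a plain-Scala sketch, not Spark's own code; `LabelCheck`, `isValidLabel` and `shiftLabel` are hypothetical names, and the +20 shift assumes the labels are whole numbers between -20 and +20, as described in the question.

```scala
// Plain-Scala sketch (no Spark needed) of the label rule quoted in the error message.
object LabelCheck {
  // A label is accepted only if it is a whole number in [0, numClasses)
  def isValidLabel(label: Double, numClasses: Int): Boolean =
    label % 1 == 0 && label >= 0 && label < numClasses

  // Hypothetical remapping: assumes whole-number labels in [-20, 20],
  // so adding 20 moves them into [0, 40], i.e. 41 classes
  def shiftLabel(label: Double, offset: Double = 20.0): Double = label + offset

  def main(args: Array[String]): Unit = {
    println(isValidLabel(-6.0, 44))             // false: negative label
    println(isValidLabel(shiftLabel(-6.0), 41)) // true: -6.0 becomes 14.0
  }
}
```

A fixed offset keeps the numeric order of the labels, which frequency-based indexing does not.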

I would check the labels, for example:

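    import spark.implicits._  // $"col" syntax; assumes a SparkSession named spark, as in spark-shell
    df.select($"value").distinct.show(100)  // distinct labels ("value" is the label column in the question)
    df.filter($"value" < 0).show()  // rows with negative labels, which the classifier rejects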
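If the check confirms that the labels are meaningful values that simply fall outside [0, numClasses), one option is to let Spark's `StringIndexer` map each distinct value to an index before training. This is a sketch, not the answer author's code: `IndexedLabelSketch`, `fitAndScore`, the `indexedLabel` column, the depth/seed values and the choice of accuracy as the metric are all assumptions. Because the indexed labels form a multiclass problem, it scores with `MulticlassClassificationEvaluator` rather than `BinaryClassificationEvaluator`, which is intended for two-class labels.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.DataFrame

object IndexedLabelSketch {
  // trainData/testData: DataFrames with a Double "value" label column and a vector features column
  def fitAndScore(trainData: DataFrame, testData: DataFrame, featuresCol: String): Double = {
    // Map each distinct "value" (e.g. -6.0) to an index in [0, numClasses)
    val labelIndexer = new StringIndexer()
      .setInputCol("value")
      .setOutputCol("indexedLabel")
      .setHandleInvalid("skip") // drop test rows whose label never appeared in training

    // Train on the indexed label instead of the raw "value" column
    val dt = new DecisionTreeClassifier()
      .setFeaturesCol(featuresCol)
      .setLabelCol("indexedLabel")
      .setMaxDepth(5)
      .setSeed(42)

    val model = new Pipeline().setStages(Array(labelIndexer, dt)).fit(trainData)
    val predictions = model.transform(testData)

    // Multiclass metric, since there are many label classes rather than two
    new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
      .evaluate(predictions)
  }
}
```

`StringIndexer` assigns indices by label frequency by default, so index 0 is the most common value rather than -20; `IndexToString` can map predictions back to the original values.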