How does the param maxCategories affect the decision tree model?

Time: 2017-05-18 06:08:24

Tags: scala apache-spark machine-learning apache-spark-mllib

I'm new to Spark ML, and in a decision tree exercise there is one thing I can't explain to myself.

During data preprocessing, I indexed the feature vector:

val featureIndexer = new VectorIndexer()
    .setInputCol(featureCol)
    .setOutputCol(indexedFeatureCol)
    .setMaxCategories(3)

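I set maxCategories to 3, put the indexer into a pipeline, and finally obtained a decision tree model. When debugging the model, I found conditions such as If (feature 4 in {0.0,1.0,2.0,3.0,4.0,5.0,10.0,11.0,13.0,14.0,15.0,18.0,19.0,20.0,22.0}), which clearly exceed the maxCategories value. So I'd like to know how this param affects the tree. Shouldn't feature 4 be treated as continuous, with a split such as If (feature 4 <= 19.0)?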
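The expectation in the question follows from VectorIndexer's documented rule, which can be modeled without Spark. The sketch below is a simplified, stdlib-only Scala model (not Spark's implementation): a feature with at most maxCategories distinct values is declared categorical and its values are re-indexed to 0 until k, and every other feature is left continuous. The names MaxCategoriesSketch and classify are hypothetical, and assigning indices in sorted value order is an assumption; Spark's exact assignment may differ.

```scala
// Simplified model of VectorIndexer's documented rule (NOT Spark's code):
// features with <= maxCategories distinct values become categorical,
// all other features stay continuous.
object MaxCategoriesSketch {

  sealed trait FeatureKind
  // Categorical feature: original value -> category index in 0 until k
  final case class Categorical(valueToIndex: Map[Double, Int]) extends FeatureKind
  case object Continuous extends FeatureKind

  /** Decide the kind of each feature (column) of a dense, row-major dataset. */
  def classify(rows: Seq[Array[Double]], maxCategories: Int): Vector[FeatureKind] = {
    require(rows.nonEmpty, "need at least one row")
    val numFeatures = rows.head.length
    (0 until numFeatures).toVector.map { j =>
      val distinct = rows.map(_(j)).distinct.sortWith(_ < _)
      if (distinct.size <= maxCategories) {
        // Assumption: indices follow sorted value order.
        Categorical(distinct.zipWithIndex.toMap)
      } else {
        Continuous
      }
    }
  }
}

// Hypothetical data: feature 0 has 3 distinct values, feature 1 has 4.
val rows = Seq(
  Array(0.0, 1.0),
  Array(1.0, 5.0),
  Array(2.0, 10.0),
  Array(0.0, 19.0)
)
val kinds = MaxCategoriesSketch.classify(rows, maxCategories = 3)
println(kinds)
```

With maxCategories = 3, feature 0 comes out categorical and feature 1 stays continuous. By this rule alone, a feature with as many distinct values as feature 4 in the question would also be continuous, which is why the categorical split in the debug output is surprising.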

Update with code:
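import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, StringIndexerModel, VectorAssembler, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.storage.StorageLevel

// DataConfig, getStringIndexers, evaluate and writeToFile are the asker's own helpers (not shown).
def dt(spark: SparkSession, df: DataFrame, config: DataConfig) = {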
  df.persist(StorageLevel.MEMORY_AND_DISK)

  val categoryCols = config.categoricalCols
  val numberCols = config.numericalCols

  val maxBins = getMaxBins(df, categoryCols :+ config.labelCol )

  // Index labels, adding metadata to the label column.
  // Fit on whole dataset to include all labels in index.
  val labelIndexer = new StringIndexer()
    .setInputCol(config.labelCol)
    .setOutputCol(config.indexedLabelCol)
    .fit(df)

  //generate categorical features
  val (models, stringFeatureCols) = if(categoryCols.length > 0) getStringIndexers(df, categoryCols)
    else (Array[StringIndexerModel](), Array[String]())

  //combine all categorical features, to generate the label mapping
  val stringIndexers = Array(labelIndexer) ++ models

  //combine all features
  val featureCol = "features"
  val indexedFeatureCol = "indexedFeatures"
  val allFeatures : Array[String] = stringFeatureCols ++ numberCols

  val assembler = new VectorAssembler()
    .setInputCols(allFeatures)
    .setOutputCol(featureCol)

  //index features
  val featureIndexer = new VectorIndexer()
    .setInputCol(featureCol)
    .setOutputCol(indexedFeatureCol)
    .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.

  // Split the data into training and test sets (30% held out for testing).
  val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))

  // initialize a DecisionTree model.
  val dt = new DecisionTreeClassifier()
    .setLabelCol(config.indexedLabelCol)
    .setFeaturesCol(indexedFeatureCol)
    .setMinInstancesPerNode(50)
    .setMaxDepth(10)
    .setMaxBins(maxBins)
    //.setMinInfoGain()



  // Convert indexed labels back to original labels.
  val labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("predictedLabel")
    .setLabels(labelIndexer.labels)

  // Chain indexers and tree in a Pipeline.
  val pipeline = new Pipeline()
    .setStages(Array(labelIndexer) ++ models ++ Array(assembler, featureIndexer, dt, labelConverter))

  // Train model. This also runs the indexers.
  val model = pipeline.fit(trainingData)

  // Make predictions.
  val predictions = model.transform(testData)

  // Select example rows to display.
  predictions.select("predictedLabel", (Array(config.labelCol, config.indexedLabelCol, "prediction") ++ config.categoricalCols):_*).show(100)

  //predictions.write.mode("overwrite").json("prediction")
  val evalStr = evaluate(spark, predictions, config.labelCol, "indexedLabel", "predictedLabel", "prediction", stringIndexers)

  val idx = 3 + models.length
  val treeModel = model.stages(idx).asInstanceOf[DecisionTreeClassificationModel]
  val idx2 = idx - 1
  val vectorIndexer = model.stages(idx2).asInstanceOf[VectorIndexerModel]


  val debugStr = "Learned classification tree model:\n" + treeModel.toDebugString
  println(debugStr)

  val featuresMapping = "Features: \n" + (categoryCols ++ numberCols).zipWithIndex.map(p => p._2 + ":" + p._1).reduce(_ + "\n" + _) + "\n"
  writeToFile(evalStr + featuresMapping + debugStr, this.getClass.getCanonicalName + "_model.txt")
}
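The code above locates the tree and the VectorIndexerModel by position (idx = 3 + models.length), which depends on the exact stage order: one label indexer, then the categorical indexers, then assembler, feature indexer, tree and label converter. The stdlib-only sketch below models that layout with hypothetical placeholder stages (Stage, TreeStage, StageLookup are illustrative names, not Spark classes) to check the index arithmetic and to contrast it with an order-independent lookup by type.

```scala
// Hypothetical stand-ins for pipeline stages; in Spark these would be
// StringIndexerModel, VectorAssembler, VectorIndexerModel, etc.
sealed trait Stage
final case class IndexerStage(col: String) extends Stage
case object AssemblerStage extends Stage
case object FeatureIndexerStage extends Stage
case object TreeStage extends Stage
case object ConverterStage extends Stage

object StageLookup {
  // Same order as setStages(Array(labelIndexer) ++ models ++ Array(assembler, featureIndexer, dt, labelConverter))
  def stages(numCategoricalIndexers: Int): Vector[Stage] =
    Vector[Stage](IndexerStage("label")) ++
      (0 until numCategoricalIndexers).map(i => IndexerStage(s"cat$i")) ++
      Vector(AssemblerStage, FeatureIndexerStage, TreeStage, ConverterStage)

  // Position arithmetic from the question: 1 label indexer + n indexers + assembler + feature indexer.
  def treeIndex(n: Int): Int = 3 + n

  // Order-independent alternative: pick the first stage of the wanted type.
  def findTree(all: Seq[Stage]): Option[Stage] =
    all.collectFirst { case t @ TreeStage => t }
}

val pipelineStages = StageLookup.stages(numCategoricalIndexers = 4)
println(pipelineStages(StageLookup.treeIndex(4))) // TreeStage
println(StageLookup.findTree(pipelineStages))     // Some(TreeStage)
```

In the real pipeline the same idea would be model.stages.collectFirst { case m: DecisionTreeClassificationModel => m }, which keeps working if stages are added or reordered. For debugging the question itself, vectorIndexer.categoryMaps.keySet lists the feature indices that VectorIndexer marked as categorical, which can be compared with the features the tree splits on by category.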

maxBins is set to the largest number of distinct values found among the original categorical columns:

def getMaxBins(df: DataFrame, categoricalCols: Array[String]): Int = {
  val maxCol = categoricalCols
    .map(col => (col, df.select(df(col)).distinct().count()))
    .reduce((l, r) => if (l._2 >= r._2) l else r)
  println(s"the column ${maxCol._1} maximum # of distinct categorical features : ${maxCol._2}")
  maxCol._2.toInt
}

Its value is 23. So I'd like to know: how does maxBins, rather than setMaxCategories, end up affecting the actual categories?
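For reference, Spark documents maxBins for decision trees as "must be >= 2 and >= number of categories in any categorical feature"; for continuous features it bounds the number of candidate split thresholds. maxBins does not itself decide which features are categorical: the tree reads that from the ML attribute metadata on the features column. Whether the StringIndexer outputs already reach the tree as nominal (categorical) attributes through VectorAssembler, independently of maxCategories, is worth checking in the Spark version used; that is an assumption, not verified here. The stdlib-only sketch below models just the documented maxBins constraint; MaxBinsCheck and its methods are hypothetical names, and the arities are made up (23 categories for feature 4, consistent with the largest value 22.0 in the question's split).

```scala
// Simplified model of the documented maxBins constraint (NOT Spark's code):
// maxBins must be >= 2 and >= the number of categories of every categorical feature.
object MaxBinsCheck {
  /** categoricalArity: feature index -> number of categories, as the tree sees it. */
  def minValidMaxBins(categoricalArity: Map[Int, Int]): Int =
    (2 +: categoricalArity.values.toSeq).max

  /** Right(maxBins) if the setting is valid, otherwise Left(reason). */
  def validate(maxBins: Int, categoricalArity: Map[Int, Int]): Either[String, Int] =
    categoricalArity.collectFirst {
      case (feature, arity) if arity > maxBins =>
        s"maxBins = $maxBins is smaller than the $arity categories of feature $feature"
    } match {
      case Some(msg)           => Left(msg)
      case None if maxBins < 2 => Left(s"maxBins = $maxBins must be >= 2")
      case None                => Right(maxBins)
    }
}

// Hypothetical arities: feature 0 has 3 categories, feature 4 has 23.
val arities = Map(0 -> 3, 4 -> 23)
println(MaxBinsCheck.minValidMaxBins(arities)) // 23
println(MaxBinsCheck.validate(23, arities))    // Right(23)
println(MaxBinsCheck.validate(16, arities))    // Left(...feature 4)
```

Under this model, a value of 23 from getMaxBins only guarantees that the tree accepts categorical features with up to 23 categories; which features are categorical in the first place is decided before the tree sees maxBins.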

0 answers