如何在Spark Pipeline中使用RandomForest

时间:2015-08-20 03:04:12

标签: apache-spark apache-spark-mllib pipeline random-forest apache-spark-ml

我想通过网格搜索和使用spark进行交叉验证来调整我的模型。在spark中,它必须将基本模型放在管道中,office demo of pipeline使用LogistictRegression作为基础模型,它可以是新的对象。但是,RandomForest模型不能通过客户端代码,因此似乎无法在管道API中使用RandomForest。我不想重新创建一个轮子,那么有人可以给出一些建议吗? 感谢

1 个答案:

答案 0 :(得分:5)

  

但是,RandomForest模型不能是客户端代码的新模型,所以它似乎无法在管道api中使用RandomForest。

嗯,这是真的,但你只是想尝试使用错误的类。您应该使用mllib.tree.RandomForest而不是ml.classification.RandomForestClassifier。以下是基于the one from MLlib docs的示例。

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._ 

case class Record(category: String, features: Vector)

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))

val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("label")

val rf  = new RandomForestClassifier()
    .setNumTrees(3)
    .setFeatureSubsetStrategy("auto")
    .setImpurity("gini")
    .setMaxDepth(4)
    .setMaxBins(32)

val pipeline = new Pipeline()
    .setStages(Array(indexer, rf))

val model = pipeline.fit(trainDF)

model.transform(testDF)

有一件事我在这里无法弄清楚。据我所知,应该可以直接使用从LabeledPoints中提取的标签,但由于某种原因它不起作用,pipeline.fit引发IllegalArgumentExcetion

  

给RandomForestClassifier输入了无效的标签列标签,没有指定类的数量。

因此StringIndexer的丑陋技巧。在应用之后,我们获得了必需的属性({"vals":["1.0","0.0"],"type":"nominal","name":"label"}),但ml中的某些类似乎没有它就可以正常工作。