Cross-validation in Apache Spark: how do I create a parameter grid?

Time: 2019-04-17 16:21:35

Tags: apache-spark apache-spark-mllib cross-validation

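I am trying to set up a ParamGrid so that I can use cross-validation later, but I cannot find any explanation of the input parameters.

After creating the pipeline, I tried to create a parameter grid, but I do not understand which entries are expected.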

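// Imports assumed for Spark 2.3/2.4, where OneHotEncoderEstimator exists
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.ParamGridBuilder

// Build the pipeline: index and one-hot encode each string column,
// assemble the feature vector, then apply linear regression to it
val IndexedList = StringList.flatMap { name =>
  val indexer = new StringIndexer().setInputCol(name).setOutputCol(name + "Index")

  val encoder = new OneHotEncoderEstimator()
    .setInputCols(Array(name + "Index"))
    .setOutputCols(Array(name + "vec"))

  Array(indexer, encoder)
}

val features = new VectorAssembler()
  .setInputCols(Array("Modellvec", "KM", "Hubraum", "Fuelvec", "Farbevec", "Typevec",
    "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8"))
  .setOutputCol("Features2")
val linReg = new LinearRegression() //.setFeaturesCol(features2.getOutputCol).setLabelCol("Preis")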

// Creates the array of stages
val IndexedList3 = (IndexedList :+ features :+ linReg).toArray[PipelineStage]

val pipeline2 = new Pipeline()

// This grid should be created in order to apply cross-validation
val pipeline_grid = new ParamGridBuilder()
  .baseOn(pipeline2.stages -> IndexedList3)
  .addGrid(linReg.regParam, Array(10,15,20,25,30,35,40,45,50,55,60,65,70,75))
  .build()

The first part works fine when run on its own.

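The problem is that I do not understand what the array passed to addGrid should look like (or how I should choose its values), nor why it is a problem here, given that linReg.regParam is of type DoubleParam and addGrid is defined for that type.

In most of the examples I have seen, this array seems to come out of nowhere. Can someone explain where it comes from?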
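For illustration, here is a minimal sketch of a grid that type-checks. The array is simply the list of candidate values that cross-validation should compare for that parameter, so it is chosen by the user rather than derived by Spark. One likely cause of the compile error is that integer literals such as Array(10, 15, 20) are inferred as Array[Int], while the DoubleParam overload of addGrid expects an Array[Double]. The object name ParamGridSketch, the specific candidate values, the extra elasticNetParam grid, and the fold count are assumptions made for this example; the column names "Features2" and "Preis" are taken from the code above.

```scala
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Hypothetical wrapper object, only here to make the sketch self-contained.
object ParamGridSketch {
  val linReg: LinearRegression = new LinearRegression()
    .setFeaturesCol("Features2")
    .setLabelCol("Preis")

  // Only the final stage is listed here; the indexers, encoders and the
  // assembler from the question would normally come before it.
  val pipeline: Pipeline = new Pipeline().setStages(Array[PipelineStage](linReg))

  // regParam is a DoubleParam, so the candidates are written as Double literals.
  // The values are illustrative assumptions, not recommendations.
  val paramGrid: Array[ParamMap] = new ParamGridBuilder()
    .addGrid(linReg.regParam, Array(0.01, 0.1, 1.0))
    .addGrid(linReg.elasticNetParam, Array(0.0, 0.5))
    .build()

  // Each ParamMap in the grid is one combination that CrossValidator will fit.
  val cv: CrossValidator = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(new RegressionEvaluator().setLabelCol("Preis"))
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(3)

  def main(args: Array[String]): Unit = {
    // 3 regParam values x 2 elasticNetParam values = 6 combinations
    println(paramGrid.length)
    paramGrid.foreach(println)
  }
}
```

Calling cv.fit(trainingData) on a DataFrame (not shown, since no data is given in the question) would then train one model per combination and fold, and keep the combination with the best average metric. Under this reading, the original grid would probably compile once its values are written as 10.0, 15.0, and so on.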

0 answers:

No answers