Spark CrossValidator如何确定如何应用网格参数

时间:2017-12-07 01:03:18

标签: scala apache-spark apache-spark-mllib

来自ML Tuning https://spark.apache.org/docs/latest/ml-tuning.html的Spark文档的以下片段显然为numFeatures设置了Hashing TermFrequency,为{{regParam设置了LogisticRegression(正则化) 1}}模型:

HashingTFLogisticRegression

val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)

CrossValidator ParamGridBuilder

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

CrossValidator“如何知道”如何将网格值应用于各个实体?我想看看它是否是通过反思但是不清楚。

`CrossValidator可能设置的方法是:

HashingTF

  /** @group setParam */
  @Since("1.2.0")
  def setNumFeatures(value: Int): this.type = set(numFeatures, value)

逻辑回归

class LogisticRegressionModel @Since("1.3.0") (
 ..
 @Since("1.3.0") val numFeatures: Int,

这是CrossValidator上的调用:

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2)  // Use 3+ in practice

我无法确定setEstimatorParamMaps如何正确设置HashingTFLogisticRegression值。 (注意 工作!)

这个问题的原因是我想添加一个新的Evaluator,我不确定如何将其与CrossValidator功能相匹配。

一个具体的例子:对于LDAModel:我们有调整参数kvocabSizedocConcentration:如何为这些参数设置ParamGrid

1 个答案:

答案 0 :(得分:1)

  

一个具体的例子:对于LDAModel:我们有调整参数k,vocabSize和docConcentration:如何为那些设置ParamGrid?

addGrid会获得ParamArray个兼容值。通常,它设置在EstimatorLDA)而不是Transformer ( LDAModel`)上。

要设置kdocConcentration只需按以下类型:

val lda = new LDA()

val paramGrid = new ParamGridBuilder()
 .addGrid(lda.k, Array(3, 5, 7))
 .addGrid(lda.docConcentration, Array(Array(0.1, 0.4, 0.5)))
 .build()
  

我们有调整参数(...)vocabSize

词汇大小由输入向量定义。它是无法调整的。

  

CrossValidator"如何知道"如何将网格值应用于各自的实体?

模型提供fit方法,该方法需要datasetParamMapFor example LDA

def fit(dataset: Dataset[_], paramMap: ParamMap): LDAModel
     

使用提供的参数图将单个模型拟合到输入数据。

此变种is used位于CrossValidator