CrossValidator with a Pipeline as the estimator: how does the data flow through the pipeline?

Time: 2018-09-09 00:00:10

Tags: scala apache-spark apache-spark-mllib

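Looking at ML Tuning: Cross-Validation, I have a question about how data flows through a Spark pipeline.

Let's imagine I want to process a dataset similar to the following one: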

| sqfeet | #rooms | neighbourhood | price |
|--------|--------|---------------|-------|
| 50     | 2      | GR4242        | 100   |
| 120    | 3      | GR4242        | 220   |
| 100    | 2      | FD0202        | 180   |

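What I do in other ML frameworks is:

  1. Preprocess all the data, for example by one-hot encoding the column neighbourhood.
  2. Split the data into train and test sets.
  3. Perform hyperparameter tuning with CV on the train set.
  4. Use the test set to get an unbiased measure of the model's performance.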
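
For reference, the four steps above could be sketched in Spark roughly as follows. This is only an illustration under several assumptions, not code from the linked guide: the object name `PreprocessOnce` and its `run` method are hypothetical, `LinearRegression` with `RegressionEvaluator` is swapped in because price is a continuous target, and the neighbourhood column is only indexed with `StringIndexer` rather than one-hot encoded, since the one-hot encoder class differs between Spark versions. The input is assumed to have the columns `sqfeet`, `#rooms`, `neighbourhood` and `label`.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.DataFrame

// Hypothetical helper: preprocess once, split, tune with CV, score on the test set.
object PreprocessOnce {

  /** Returns the RMSE of the tuned model on the held-out test split. */
  def run(raw: DataFrame, seed: Long = 42L): Double = {
    // Step 1: run the preprocessing stages a single time, outside the CrossValidator.
    val preprocessing = new Pipeline().setStages(Array(
      new StringIndexer()
        .setInputCol("neighbourhood")
        .setOutputCol("neighbourhood_index"),
      new VectorAssembler()
        .setInputCols(Array("sqfeet", "#rooms", "neighbourhood_index"))
        .setOutputCol("features")
    ))
    val prepared = preprocessing.fit(raw).transform(raw).cache()

    // Step 2: split the already-preprocessed data into train and test.
    val Array(train, test) = prepared.randomSplit(Array(0.8, 0.2), seed)

    // Step 3: cross-validate only the final estimator on the train split.
    val lr = new LinearRegression().setMaxIter(10)
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()
    val evaluator = new RegressionEvaluator().setMetricName("rmse")
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2)
    val cvModel = cv.fit(train)

    // Step 4: evaluate the best model on the untouched test split.
    evaluator.evaluate(cvModel.transform(test))
  }
}
```

One caveat with this layout: any preprocessing stage that learns from the data (the `StringIndexer` here, or a scaler) is fitted on rows that later end up in validation and test splits. That is usually harmless for category indexing, but it can leak information for statistics-based stages, which is presumably why `CrossValidator` refits the whole pipeline per fold.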

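However, when I add a custom Transformer to the code from the page linked above, its transform method is called 2 × (number of CV folds) × (number of parameter combinations in the grid) times, plus one final time on the full dataset.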

Test program:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}


// A plain mutable counter is not thread-safe, but that is fine here: the parallelism is set to 1
object Counter {
  var timesCalled = 0
}

class CustomTransformer(override val uid: String = Identifiable.randomUID("custom"))
  extends Transformer {

  override def transform(dataset: Dataset[_]): DataFrame = {
    println(s"Times called ${Counter.timesCalled}. Dataset passed:")
    dataset.show()
    Counter.timesCalled += 1
    dataset.toDF()
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema
}

object TestApp {

  def main(args: Array[String]): Unit = {

    lazy val spark: SparkSession =
      SparkSession
        .builder()
        .appName("testApp")
        .master("local[1]")
        .getOrCreate()

    // Prepare training data from a list of (sqfeet, #rooms, neighbourhood, label) tuples.
    val training = spark.createDataFrame(Seq(
      (50f, 2, "GR4242", 100f),
      (120f, 3, "GR4242", 220f),
      (100f, 2, "FD0202", 180f)
    )).toDF("sqfeet", "#rooms", "neighbourhood", "label")

    //    val stringIndexer = new StringIndexer()
    //      .setInputCol("neighbourhood")
    //      .setOutputCol("neighbourhood_index")
    val assembler = new VectorAssembler()
      .setInputCols(Array("sqfeet", "#rooms"))
      .setOutputCol("features")
    val customTransformer = new CustomTransformer()
    val lr = new LogisticRegression()
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array(assembler, customTransformer, lr))

    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2)

    val cvModel = cv.fit(training)

  }

} 

The output is:

Times called 0. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 1. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 2. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 3. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 4. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 5. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 6. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 7. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 8. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

I guess the odd-numbered calls represent the training phase and the even-numbered ones the test phase.
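
The call pattern in the output can be reproduced with a small pure-Scala trace. This is a sketch of what the row counts above suggest, not a copy of Spark's implementation: it assumes that, for every fold and every parameter combination, the pipeline is first fitted on the training split (one `transform` call on the custom stage) and then applied to the validation split (a second call), and that the best parameters are finally refitted on the whole dataset. The names `CvCallTrace`, `Phase` and `trace` are hypothetical.

```scala
// Hypothetical model of when a pipeline transformer's transform() gets called
// during CrossValidator.fit, in call order.
object CvCallTrace {
  sealed trait Phase
  case object Train extends Phase       // pipeline.fit on the fold's training split
  case object Validation extends Phase  // fitted model applied to the validation split
  case object FinalRefit extends Phase  // best params refitted on the full dataset

  def trace(numFolds: Int, gridSize: Int): Seq[Phase] = {
    val perFold = for {
      _     <- 0 until numFolds
      _     <- 0 until gridSize
      phase <- Seq[Phase](Train, Validation)
    } yield phase
    // One extra call at the end for the refit on all rows.
    perFold :+ FinalRefit
  }
}
```

With `numFolds = 2` and `gridSize = 2`, as in the test program, `CvCallTrace.trace(2, 2).length` is 9, matching the nine printed calls. Counting from zero as the printed counter does, calls 0, 2, 4 and 6 would be training calls and calls 1, 3, 5 and 7 validation calls; counting from one, the odd calls are the training ones. Call 8, which shows all three rows, would be the final refit.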

Wouldn't it be more efficient to run all the expensive preprocessing done in the transformers just once?

0 answers:

No answers yet.