Looking at ML Tuning: Cross-Validation, I have a question about how data flows through a Spark pipeline.
Let's imagine I want to process a dataset similar to the following:
| sqfeet | #rooms | neighbourhood | price |
|--------|--------|---------------|-------|
| 50 | 2 | GR4242 | 100 |
| 120 | 3 | GR4242 | 220 |
| 100 | 2 | FD0202 | 180 |
What I would do in other ML frameworks is one-hot encode `neighbourhood`. However, with the code linked above, the `transform` method of my custom Transformer is called 2 × (number of CV folds) × (cross product of the parameter grid) times.

Test program:
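For reference, the one-hot encoding step I have in mind would look roughly like this in Spark ML (a sketch; in Spark 2.3–2.4 the encoder estimator is named `OneHotEncoderEstimator` instead of `OneHotEncoder`, and the output column names here are my own choice):

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Map the categorical string column to numeric indices first...
val indexer = new StringIndexer()
  .setInputCol("neighbourhood")
  .setOutputCol("neighbourhood_index")

// ...then one-hot encode the indices into a sparse vector column.
val encoder = new OneHotEncoder()
  .setInputCol("neighbourhood_index")
  .setOutputCol("neighbourhood_vec")

// Both stages would then be prepended to the pipeline's stage array.
```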
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// Don't worry, the parallelism is set to 1
object Counter {
  var timesCalled = 0
}

// A pass-through Transformer that logs every dataset handed to it.
class CustomTransformer(override val uid: String = Identifiable.randomUID("custom"))
    extends Transformer {

  override def transform(dataset: Dataset[_]): DataFrame = {
    println(s"Times called ${Counter.timesCalled}. Dataset passed:")
    dataset.show()
    Counter.timesCalled += 1
    dataset.toDF()
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema
}

object TestApp {
  def main(args: Array[String]): Unit = {
    lazy val spark: SparkSession =
      SparkSession
        .builder()
        .appName("testApp")
        .master("local[1]")
        .getOrCreate()

    // Prepare training data from a list of (sqfeet, #rooms, neighbourhood, price) tuples.
    val training = spark.createDataFrame(Seq(
      (50f, 2, "GR4242", 100f),
      (120f, 3, "GR4242", 220f),
      (100f, 2, "FD0202", 180f)
    )).toDF("sqfeet", "#rooms", "neighbourhood", "label")

    // val stringIndexer = new StringIndexer()
    //   .setInputCol("neighbourhood")
    //   .setOutputCol("neighbourhood_index")

    val assembler = new VectorAssembler()
      .setInputCols(Array("sqfeet", "#rooms"))
      .setOutputCol("features")

    val customTransformer = new CustomTransformer()

    val lr = new LogisticRegression()
      .setMaxIter(10)

    val pipeline = new Pipeline()
      .setStages(Array(assembler, customTransformer, lr))

    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2)

    val cvModel = cv.fit(training)
  }
}
The output is:
Times called 0. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 1. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 2. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 3. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 4. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 5. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 6. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 7. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 8. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
I suppose the odd-numbered calls correspond to the training phase and the even-numbered ones to the testing phase.
Wouldn't it be more efficient to run all the expensive preprocessing computation done in the Transformer only once?
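One way I can imagine doing this (a sketch under my own assumptions, reusing the `training`, `assembler`, `customTransformer`, `lr`, and `paramGrid` values from the test program above, not necessarily the canonical answer) is to apply the parameter-independent preprocessing once, outside the CrossValidator, and only tune the estimator stage inside it:

```scala
// Sketch: run the parameter-independent preprocessing a single time,
// cache the result, and cross-validate only the tunable stage.
val preprocessed = customTransformer.transform(
  assembler.transform(training)
).cache() // cache so every fold/param combination reuses the materialized rows

val cv2 = new CrossValidator()
  .setEstimator(new Pipeline().setStages(Array(lr))) // only the stage being tuned
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2)

val cvModel2 = cv2.fit(preprocessed)
```

Though, as far as I understand, this is only safe when the preprocessing learns nothing from the data; a fitted stage (like a StringIndexer or scaler) moved outside the folds would leak information from the validation split into training.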