在Spark中将连续变量转换为分类变量

时间:2017-12-04 17:32:10

标签: scala apache-spark recode

我正在尝试将一些连续变量转换为分类变量,以便将一些ML算法应用于它们,并且我想制作类似于6:00到12:00的类别 - > "上午"或日期格式如ddMM to" Summer"管他呢。 这些变量已经转换为整数。和R中的recode函数一样,我想。

+----------+
|CRSDepTime|
+----------+
|       745|
|      1053|
|      1915|
|      1755|
|       832|
|       630|
|       820|
|       945|
|      1245|
|      1645|
|       620|
|      1125|
|      2045|
|      1340|
|      1540|
|       730|
|      1145|
|       525|
|       630|
|      1520|
+----------+

我用这句话解决了这个问题!!

df = df.withColumn("Season", when(df("Month") >= 12 and df("Month") <=3, "Fall")
  .when(df("Month") >= 4 and df("Month") <= 6, "Spring")
  .when(df("Month") >= 7 and df("Month") <= 9, "Summer").otherwise("Autumm"))

1 个答案:

答案 0 :(得分:3)

有两个Transformers可用于将连续变量转换为分类变量:

  • Bucketizer
  • QuantileDiscretizer

Bucketizer进行拆分,因此可以在此处使用:

import org.apache.spark.ml.feature._

val df = Seq(
  745, 1053, 1915, 1755, 832, 630, 820, 945,
  1245, 1645, 620, 1125, 2045, 1340, 1540, 730,
  1145, 525, 630, 1520
).toDF("CRSDepTime")

val bucketizer = new Bucketizer()
  .setInputCol("CRSDepTime")
  .setOutputCol("bucketedFeatures")
  .setSplits(Array(0, 600, 1200, 1800, 2400))

// +----------+----------------+
// |CRSDepTime|bucketedFeatures|
// +----------+----------------+
// |       745|             1.0|
// |      1053|             1.0|
// |      1915|             3.0|
// |      1755|             2.0|
// |       832|             1.0|
// |       630|             1.0|
// |       820|             1.0|
// |       945|             1.0|
// |      1245|             2.0|
// |      1645|             2.0|
// +----------+----------------+
// only showing top 10 rows

通常会与OneHotEncoder

结合使用
import org.apache.spark.ml.Pipeline

val encoder = new OneHotEncoder()
  .setInputCol(bucketizer.getOutputCol)
  .setOutputCol("CRSDepTimeencoded")

val pipeline = new Pipeline().setStages(Array(bucketizer, encoder))

pipeline.fit(df).transform(df).show(10)

// +----------+----------------+-----------------+
// |CRSDepTime|bucketedFeatures|CRSDepTimeencoded|
// +----------+----------------+-----------------+
// |       745|             1.0|    (3,[1],[1.0])|
// |      1053|             1.0|    (3,[1],[1.0])|
// |      1915|             3.0|        (3,[],[])|
// |      1755|             2.0|    (3,[2],[1.0])|
// |       832|             1.0|    (3,[1],[1.0])|
// |       630|             1.0|    (3,[1],[1.0])|
// |       820|             1.0|    (3,[1],[1.0])|
// |       945|             1.0|    (3,[1],[1.0])|
// |      1245|             2.0|    (3,[2],[1.0])|
// |      1645|             2.0|    (3,[2],[1.0])|
// +----------+----------------+-----------------+
// only showing top 10 rows