像一个热编码器一样,如何将单个多个分类列拆分为二进制,使用Spark Scala?

时间:2019-02-25 07:09:47

标签: scala apache-spark one-hot-encoding

我的数据是这样的:

+---+---------+
| id|cate_list|
+---+---------+
|  0|  a,b,c,d|
|  1|    b,c,d|
|  2|      a,b|
|  3|        a|
|  4|a,b,c,d,e|
|  5|        e|
+---+---------+

我想要的是这样的:

-------------------------
| id|cate_list|a|b|c|d|e|
-------------------------
|  0|  a,b,c,d|1|1|1|1|0|
|  1|    b,c,d|0|1|1|1|0|
|  2|      a,b|1|1|0|0|0|
|  3|        a|1|0|0|0|0|
|  4|a,b,c,d,e|1|1|1|1|1|
|  5|        e|0|0|0|0|1|
-------------------------

我使用spark ML OneHotEncoder并尝试了许多方法,最后我得到了:

+---+---------+-------------+-------------+
| id|cate_list|categoryIndex|  categoryVec|
+---+---------+-------------+-------------+
|  0|        a|          0.0|(4,[0],[1.0])|
|  0|        b|          1.0|(4,[1],[1.0])|
|  0|        c|          2.0|(4,[2],[1.0])|
|  0|        d|          3.0|(4,[3],[1.0])|
|  1|        b|          1.0|(4,[1],[1.0])|
|  1|        c|          2.0|(4,[2],[1.0])|
|  1|        d|          3.0|(4,[3],[1.0])|
|  2|        a|          0.0|(4,[0],[1.0])|
|  2|        b|          1.0|(4,[1],[1.0])|
|  3|        a|          0.0|(4,[0],[1.0])|
|  4|        a|          0.0|(4,[0],[1.0])|
|  4|        b|          1.0|(4,[1],[1.0])|
|  4|        c|          2.0|(4,[2],[1.0])|
|  4|        d|          3.0|(4,[3],[1.0])|
|  4|        e|          4.0|    (4,[],[])|
|  5|        e|          4.0|    (4,[],[])|
+---+---------+-------------+-------------+

这不是我所需要的。当我使用python时,它非常简单,几乎两行代码可以解决此问题。 Scala太难了。

我的代码:

val df_split = df.withColumn("cate_list", explode(split($"cate_list", ",")))

val indexer = new StringIndexer()
  .setInputCol("cate_list")
  .setOutputCol("categoryIndex")
  .fit(df_split)
val indexed = indexer.transform(df_split)

val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
val encoded = encoder.transform(indexed)

2 个答案:

答案 0 :(得分:1)

天真和直接的方法来处理问题的初始数据。

我们应该有一个udf来计算目标单元格值,并期望cate_list值和标记行名称:

val cateListContains = udf((cateList: String, item: String) => if (cateList.contains(item)) 1 else 0)

我们有一系列要提取的列名:

val targetColumns = Seq("a", "b", "c", "d", "e")

然后在源foldLeftDataFrame

val resultDf = targetColumns.foldLeft(dfSrc) {
  case (df, item) => 
    df.withColumn(item, cateListContains($"cate_list", lit(item)))
}

它准确地产生:

+---+---------+---+---+---+---+---+
|id |cate_list|a  |b  |c  |d  |e  |
+---+---------+---+---+---+---+---+
|0  |a,b,c,d  |1  |1  |1  |1  |0  |
|1  |b,c,d    |0  |1  |1  |1  |0  |
|2  |a,b      |1  |1  |0  |0  |0  |
|3  |a        |1  |0  |0  |0  |0  |
|4  |a,b,c,d,e|1  |1  |1  |1  |1  |
|5  |e        |0  |0  |0  |0  |1  |
+---+---------+---+---+---+---+---+

答案 1 :(得分:0)

您可以使用array_contains,它返回一个布尔值,然后将其强制转换为int

import org.apache.spark.sql.functions.array_contains

val aa = sc.parallelize(Array((0, "a,b,c,d"), (1, "b,c,d"), (2, "a, b"), (3, "a"), (4, "a,b,c,d,e"), (5, "e")))
var df = aa.toDF("id", "cate_list")   // create your data
val categories = Seq("a", "b", "c", "d", "e")
categories.foreach {col => 
  df = df.withColumn(col, array_contains(split($"cate_list", ","), col).cast("int"))
}
df.show()

结果:

+---+---------+---+---+---+---+---+
| id|cate_list|  a|  b|  c|  d|  e|
+---+---------+---+---+---+---+---+
|  0|  a,b,c,d|  1|  1|  1|  1|  0|
|  1|    b,c,d|  0|  1|  1|  1|  0|
|  2|     a, b|  1|  0|  0|  0|  0|
|  3|        a|  1|  0|  0|  0|  0|
|  4|a,b,c,d,e|  1|  1|  1|  1|  1|
|  5|        e|  0|  0|  0|  0|  1|
+---+---------+---+---+---+---+---+