我的数据是这样的:
+---+---------+
| id|cate_list|
+---+---------+
| 0| a,b,c,d|
| 1| b,c,d|
| 2| a,b|
| 3| a|
| 4|a,b,c,d,e|
| 5| e|
+---+---------+
我想要的是这样的:
-------------------------
| id|cate_list|a|b|c|d|e|
-------------------------
| 0| a,b,c,d|1|1|1|1|0|
| 1| b,c,d|0|1|1|1|0|
| 2| a,b|1|1|0|0|0|
| 3| a|1|0|0|0|0|
| 4|a,b,c,d,e|1|1|1|1|1|
| 5| e|0|0|0|0|1|
-------------------------
我使用spark ML OneHotEncoder并尝试了许多方法,最后我得到了:
+---+---------+-------------+-------------+
| id|cate_list|categoryIndex| categoryVec|
+---+---------+-------------+-------------+
| 0| a| 0.0|(4,[0],[1.0])|
| 0| b| 1.0|(4,[1],[1.0])|
| 0| c| 2.0|(4,[2],[1.0])|
| 0| d| 3.0|(4,[3],[1.0])|
| 1| b| 1.0|(4,[1],[1.0])|
| 1| c| 2.0|(4,[2],[1.0])|
| 1| d| 3.0|(4,[3],[1.0])|
| 2| a| 0.0|(4,[0],[1.0])|
| 2| b| 1.0|(4,[1],[1.0])|
| 3| a| 0.0|(4,[0],[1.0])|
| 4| a| 0.0|(4,[0],[1.0])|
| 4| b| 1.0|(4,[1],[1.0])|
| 4| c| 2.0|(4,[2],[1.0])|
| 4| d| 3.0|(4,[3],[1.0])|
| 4| e| 4.0| (4,[],[])|
| 5| e| 4.0| (4,[],[])|
+---+---------+-------------+-------------+
这不是我所需要的。当我使用python时,它非常简单,几乎两行代码可以解决此问题。 Scala太难了。
我的代码:
val df_split = df.withColumn("cate_list", explode(split($"cate_list", ",")))
val indexer = new StringIndexer()
.setInputCol("cate_list")
.setOutputCol("categoryIndex")
.fit(df_split)
val indexed = indexer.transform(df_split)
val encoder = new OneHotEncoder()
.setInputCol("categoryIndex")
.setOutputCol("categoryVec")
val encoded = encoder.transform(indexed)
答案 0 :(得分:1)
天真和直接的方法来处理问题的初始数据。
我们应该有一个udf
来计算目标单元格值,并期望cate_list
值和标记行名称:
val cateListContains = udf((cateList: String, item: String) => if (cateList.contains(item)) 1 else 0)
我们有一系列要提取的列名:
val targetColumns = Seq("a", "b", "c", "d", "e")
然后在源foldLeft
上DataFrame
:
val resultDf = targetColumns.foldLeft(dfSrc) {
case (df, item) =>
df.withColumn(item, cateListContains($"cate_list", lit(item)))
}
它准确地产生:
+---+---------+---+---+---+---+---+
|id |cate_list|a |b |c |d |e |
+---+---------+---+---+---+---+---+
|0 |a,b,c,d |1 |1 |1 |1 |0 |
|1 |b,c,d |0 |1 |1 |1 |0 |
|2 |a,b |1 |1 |0 |0 |0 |
|3 |a |1 |0 |0 |0 |0 |
|4 |a,b,c,d,e|1 |1 |1 |1 |1 |
|5 |e |0 |0 |0 |0 |1 |
+---+---------+---+---+---+---+---+
答案 1 :(得分:0)
您可以使用array_contains
,它返回一个布尔值,然后将其强制转换为int
。
import org.apache.spark.sql.functions.array_contains
val aa = sc.parallelize(Array((0, "a,b,c,d"), (1, "b,c,d"), (2, "a, b"), (3, "a"), (4, "a,b,c,d,e"), (5, "e")))
var df = aa.toDF("id", "cate_list") // create your data
val categories = Seq("a", "b", "c", "d", "e")
categories.foreach {col =>
df = df.withColumn(col, array_contains(split($"cate_list", ","), col).cast("int"))
}
df.show()
结果:
+---+---------+---+---+---+---+---+
| id|cate_list| a| b| c| d| e|
+---+---------+---+---+---+---+---+
| 0| a,b,c,d| 1| 1| 1| 1| 0|
| 1| b,c,d| 0| 1| 1| 1| 0|
| 2| a, b| 1| 0| 0| 0| 0|
| 3| a| 1| 0| 0| 0| 0|
| 4|a,b,c,d,e| 1| 1| 1| 1| 1|
| 5| e| 0| 0| 0| 0| 1|
+---+---------+---+---+---+---+---+