How to do one-hot encoding on a string column with comma-separated values in Spark?

Asked: 2020-04-20 15:50:34

Tags: apache-spark

I have a DataFrame that looks like this:

val df = Seq(
(1,"a,b,c"),
(2,"b,c")
).toDF("id","page_path")
df.createOrReplaceTempView("df")

df.show()


+---+---------+
| id|page_path|
+---+---------+
|  1|    a,b,c|
|  2|      b,c|
+---+---------+

I want to perform one-hot encoding on the page_path column so that the output looks like this:

[image: expected output, one 0/1 column per distinct page_path value]

Can I do one-hot encoding like this in Spark?

2 Answers:

Answer 0 (score: 3):

You can split the page_path column, then explode the values and pivot on them:

import org.apache.spark.sql.functions.{explode, split}

df
  .withColumn("splitted", split($"page_path", ","))  // "a,b,c" -> ["a","b","c"]
  .withColumn("exploded", explode($"splitted"))      // one row per array element
  .groupBy("id")
  .pivot("exploded")                                 // one column per distinct value
  .count()
  .na.fill(0)                                        // replace nulls with 0

Output:

+---+---+---+---+
|id |a  |b  |c  |
+---+---+---+---+
|1  |1  |1  |1  |
|2  |0  |1  |1  |
+---+---+---+---+
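
If the distinct page values are known up front, a minimal alternative sketch (assuming the fixed value list Seq("a", "b", "c"), which is illustrative and not part of the original answer) avoids the explode/pivot shuffle by testing array membership directly; array_contains plus a cast to int yields the same 0/1 flags:

import org.apache.spark.sql.functions.{array_contains, col, split}

// Hypothetical fixed value list; in practice it could be collected from the data.
val pages = Seq("a", "b", "c")
val withArray = df.withColumn("pages", split(col("page_path"), ","))
val encoded = pages.foldLeft(withArray) { (acc, p) =>
  // 1 if the row's page_path contains p, else 0
  acc.withColumn(p, array_contains(col("pages"), p).cast("int"))
}.drop("pages")

encoded.show()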

Answer 1 (score: 1):

Since you mentioned df.createOrReplaceTempView("df") in the question, here is the SQL version of pasha's approach, doing the same thing.

The Databricks documentation mentions many use cases for PIVOT... below is the SQL version for SQL lovers.

In this approach, SQL's PIVOT does implicit grouping, as opposed to the DataFrame API where groupBy is called explicitly, so no separate GROUP BY clause is needed.

import org.apache.spark.sql.DataFrame

val df: DataFrame = Seq((1, "a,b,c"), (2, "b,c")).toDF("id", "page_path")
df.createOrReplaceTempView("df")
spark.sql(
  """
    |Select * from
    |( select id, explode(split( page_path ,',')) as exploded from df )
    |-- alias each value so the output columns read is_a, is_b, is_c;
    |-- the bare values 'a','b','c' must match what explode produces
    |pivot(count(exploded) for exploded in ('a' as is_a, 'b' as is_b, 'c' as is_c)
    |)
  """.stripMargin).na.fill(0).show

Result:

+---+----+----+----+
| id|is_a|is_b|is_c|
+---+----+----+----+
|  1|   1|   1|   1|
|  2|   0|   1|   1|
+---+----+----+----+
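
As a hedged extension beyond the original answer: when the distinct values are not known ahead of time, the IN list can be built dynamically by collecting them first (the is_ column prefix here is illustrative):

// Collect the distinct exploded values (safe only for low-cardinality columns).
import spark.implicits._
val values = spark.sql(
  "select distinct explode(split(page_path, ',')) as v from df"
).as[String].collect().sorted

// Build the pivot IN clause, e.g. 'a' as is_a, 'b' as is_b, 'c' as is_c
val inClause = values.map(v => s"'$v' as is_$v").mkString(", ")

spark.sql(
  s"""Select * from
     |( select id, explode(split(page_path, ',')) as exploded from df )
     |pivot( count(exploded) for exploded in ($inClause) )
   """.stripMargin).na.fill(0).show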
