Split a row of a Spark Dataset into multiple rows, adding a column, using flatMap

Asked: 2016-12-26 01:10:59

Tags: scala apache-spark explode flatmap

I have a DataFrame with the following schema:

root
 |-- journal: string (nullable = true)
 |-- topicDistribution: vector (nullable = true)

The topicDistribution field is a vector of doubles: [0.1, 0.2, 0.15, ...]
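
For reference, a DataFrame with this shape can be built like so (a minimal sketch, assuming Spark 2.x with spark.implicits._ in scope; the journal names and probabilities are made up):

import org.apache.spark.ml.linalg.Vectors

// Hypothetical sample rows matching the schema above
val df = Seq(
  ("Nature", Vectors.dense(0.1, 0.2, 0.15)),
  ("Science", Vectors.dense(0.3, 0.05, 0.25))
).toDF("journal", "topicDistribution")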

What I want is to split each row into several rows, yielding the following schema:

root
 |-- journal: string
 |-- topic-prob: double // this is the value from the vector
 |-- topic-id : integer // this is the index of the value from the vector

To clarify, I have created a case class:

case class JournalDis(journal: String, topic_id: Integer, prob: Double)

I managed to achieve this using dataset.explode, in a very awkward way:

import org.apache.spark.ml.linalg.DenseVector

val df1 = df.explode("topicDistribution", "topic") {
    topics: DenseVector => topics.toArray.zipWithIndex
}.select("journal", "topic")
val df2 = df1.withColumn("topic_id", df1("topic").getItem("_2"))    // _2 = index
  .withColumn("topic_prob", df1("topic").getItem("_1"))             // _1 = probability
  .drop(df1("topic"))

However, dataset.explode is deprecated. I would like to know how to achieve the same result using the flatMap method instead.

1 Answer:

Answer 0 (score: 1)

Untested, but it should work:

import spark.implicits._
import org.apache.spark.ml.linalg.Vector

df.as[(String, Vector)].flatMap {
  // j = journal, ps = topicDistribution vector
  case (j, ps) => ps.toArray.zipWithIndex.map {
    // emit one row per (probability, index) pair in the vector
    case (p, ti) => JournalDis(j, ti, p)
  }
}
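
This produces a Dataset[JournalDis] with columns journal, topic_id, and prob. If the hyphenated names from the target schema are needed, the columns can be renamed afterwards (a small usage sketch; `ds` stands for the Dataset produced above):

// ds is the Dataset[JournalDis] from the flatMap above
val result = ds.withColumnRenamed("topic_id", "topic-id")
  .withColumnRenamed("prob", "topic-prob")
result.printSchema()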