Spark Scala: merging multiple DataFrames

Date: 2017-06-20 22:03:06

Tags: scala apache-spark dataframe merge

I have three files:

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ab|  ac|
## |  2| bb|  bc|  bd|
## +---+---+----+----+

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ab|  ad|
## |  2| bb|  bb|  bd|
## +---+---+----+----+

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ac|  ad|
## |  2| bb|  bc|  bd|
## +---+---+----+----+

I need to compare the first two files (which I'm reading into DataFrames), identify only the changes, and then merge with the third file, so my output should be:

## +---+---+----+----+
## |pk1|pk2|val1|val2|
## +---+---+----+----+
## |  1| aa|  ac|  ad|
## |  2| bb|  bb|  bd|
## +---+---+----+----+

How can I select only the changed columns and update the other DataFrame?

2 Answers:

Answer 0 (score: 1)

I can't comment yet, so I'll take a stab at this; it may need revising. As far as I can tell, you're looking for the last unique change. So val1 goes through {ab -> ab -> ac, bc -> bb -> bc} and the final result is {ac, bb}, because the bc in the last file already appears in the first file and is therefore not unique. If that's the case, the best way to handle this is to build a Set and take its last value. I'll use a UDF to do that.
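To see why Set(a, b, c).last picks the right value: Scala's small immutable Sets (up to four elements) iterate in insertion order, so .last is the most recently added distinct element (note this guarantee disappears once a Set grows beyond four elements and becomes a HashSet). A quick illustration:

// Small immutable Sets preserve insertion order, so .last is the
// last *new* value seen across the three files.
Set("ab", "ab", "ac").last  // "ac" -- "ab" repeats, "ac" is the last new value
Set("bc", "bb", "bc").last  // "bb" -- the trailing "bc" was already seen in the first file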

So, from your example:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.UserDefinedFunction
import sqlContext.implicits._

val df1: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ac"),(2,"bb","bc","bd"))).toDF("pk1","pk2","val1","val2")
val df2: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ad"),(2,"bb","bb","bd"))).toDF("pk1","pk2","val1","val2")
val df3: DataFrame = sparkContext.parallelize(Seq((1,"aa","ac","ad"),(2,"bb","bc","bd"))).toDF("pk1","pk2","val1","val2")

def getChange: UserDefinedFunction = 
    udf((a: String, b: String, c: String) => Set(a,b,c).last)

df1
  .join(df2, df1("pk1") === df2("pk1") && df1("pk2") === df2("pk2"), "inner")
  .join(df3, df1("pk1") === df3("pk1") && df1("pk2") === df3("pk2"), "inner")
  .select(df1("pk1"), df1("pk2"),
    df1("val1").as("df1Val1"), df2("val1").as("df2Val1"), df3("val1").as("df3Val1"),
    df1("val2").as("df1Val2"), df2("val2").as("df2Val2"), df3("val2").as("df3Val2"))
  .withColumn("val1", getChange($"df1Val1", $"df2Val1", $"df3Val1"))
  .withColumn("val2", getChange($"df1Val2", $"df2Val2", $"df3Val2"))
  .select($"pk1", $"pk2", $"val1", $"val2")
  .orderBy($"pk1")
  .show(false)

This produces:

+---+---+----+----+
|pk1|pk2|val1|val2|
+---+---+----+----+
|1  |aa |ac  |ad  |
|2  |bb |bb  |bd  |
+---+---+----+----+

Obviously this gets a bit cumbersome to write out once you have more columns or more DataFrames, but it should do the trick for your example.

Edit

This is for adding more columns to the mix. As I said, it's a bit cumbersome. It iterates over each column until none are left.

require(df1.columns.sameElements(df2.columns) && df1.columns.sameElements(df3.columns),"DF Columns do not match") //this is a check so may not be needed

val cols: Array[String] = df1.columns

def getChange: UserDefinedFunction = udf((a: String, b: String, c: String) => Set(a,b,c).last)

def createFrame(cols: Array[String], df1: DataFrame, df2: DataFrame, df3: DataFrame): scala.collection.mutable.ListBuffer[DataFrame] = {

  val list: scala.collection.mutable.ListBuffer[DataFrame] = new scala.collection.mutable.ListBuffer[DataFrame]()
  val keys = cols.slice(0, 2)                    // get the keys
  val columns = cols.slice(2, cols.length).toSeq // get the columns to use

  def helper(columns: Seq[String]): scala.collection.mutable.ListBuffer[DataFrame] = {
    if (columns.isEmpty) list
    else {
      list += df1
        .join(df2, df1.col(keys(0)) === df2.col(keys(0)) && df1.col(keys(1)) === df2.col(keys(1)), "inner")
        .join(df3, df1.col(keys(0)) === df3.col(keys(0)) && df1.col(keys(1)) === df3.col(keys(1)), "inner")
        .select(df1.col(keys(0)), df1.col(keys(1)),
          getChange(df1.col(columns.head), df2.col(columns.head), df3.col(columns.head)).as(columns.head))

      helper(columns.tail) // use tail recursion
    }
  }

  helper(columns)
}

val list: scala.collection.mutable.ListBuffer[DataFrame] = createFrame(cols, df1, df2, df3)

list.reduce((a, b) =>
    a.join(b, a(cols.head) === b(cols.head) && a(cols(1)) === b(cols(1)), "inner")
      .drop(b(cols.head))
      .drop(b(cols(1))))
  .select(cols.head, cols.tail: _*)
  .orderBy(cols.head)
  .show

An example with 3 value columns, which you then pass through the code above:

val df1: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ac","ad"),(2,"bb","bc","bd","bc"))).toDF("pk1","pk2","val1","val2","val3")
val df2: DataFrame = sparkContext.parallelize(Seq((1,"aa","ab","ad","ae"),(2,"bb","bb","bd","bf"))).toDF("pk1","pk2","val1","val2","val3")
val df3: DataFrame = sparkContext.parallelize(Seq((1,"aa","ac","ad","ae"),(2,"bb","bc","bd","bg"))).toDF("pk1","pk2","val1","val2","val3")

Running the code above on these DataFrames produces:

//output
+---+---+----+----+----+
|pk1|pk2|val1|val2|val3|
+---+---+----+----+----+
|  1| aa|  ac|  ad|  ae|
|  2| bb|  bb|  bd|  bg|
+---+---+----+----+----+

There may be a more efficient way to do this, but nothing comes to mind off the top of my head.
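That said, one way to avoid the intermediate list entirely (a sketch only, untested; it reuses the getChange UDF and the df1/df2/df3 defined above and assumes the first two columns are the keys) is to join the three DataFrames once and compute every value column in the same select:

// Untested sketch: one three-way join, with every value column computed in the
// same select, instead of one DataFrame per column plus a final reduce.
val keyCols   = df1.columns.take(2).toSeq   // pk1, pk2
val valueCols = df1.columns.drop(2).toSeq   // val1, val2, ...

val merged = df1
  .join(df2, keyCols, "inner")
  .join(df3, keyCols, "inner")
  .select(
    keyCols.map(df1.col) ++
      valueCols.map(c => getChange(df1.col(c), df2.col(c), df3.col(c)).as(c)): _*
  )

merged.orderBy(keyCols.head).show(false)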

Edit 2

To do this with any number of keys, you can do the following. You need to define the number of keys at the start, and this could also be cleaned up. I've run it with 4 and 5 keys; you should run some tests of your own, but it should work:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.UserDefinedFunction

val df1: DataFrame = sparkContext.parallelize(Seq((1,"aa","c","d","ab","ac","ad"),(2,"bb","d","e","bc","bd","bc"))).toDF("pk1","pk2","pk3","pk4","val1","val2","val3")
val df2: DataFrame = sparkContext.parallelize(Seq((1,"aa","c","d","ab","ad","ae"),(2,"bb","d","e","bb","bd","bf"))).toDF("pk1","pk2","pk3","pk4","val1","val2","val3")
val df3: DataFrame = sparkContext.parallelize(Seq((1,"aa","c","d","ac","ad","ae"),(2,"bb","d","e","bc","bd","bg"))).toDF("pk1","pk2","pk3","pk4","val1","val2","val3")

require(df1.columns.sameElements(df2.columns) && df1.columns.sameElements(df3.columns),"DF Columns do not match")

val cols: Array[String] = df1.columns

def getChange: UserDefinedFunction = udf((a: String, b: String, c: String) => Set(a,b,c).last)

def createFrame(cols: Array[String], df1: DataFrame, df2: DataFrame, df3: DataFrame): scala.collection.mutable.ListBuffer[DataFrame] = {

  val list: scala.collection.mutable.ListBuffer[DataFrame] = new scala.collection.mutable.ListBuffer[DataFrame]()
  val keys = cols.slice(0, 4)                    // get the keys
  val columns = cols.slice(4, cols.length).toSeq // get the columns to use

  def helper(columns: Seq[String]): scala.collection.mutable.ListBuffer[DataFrame] = {
    if (columns.isEmpty) list
    else {
      list += df1
        .join(df2, Seq(keys: _*), "inner")
        .join(df3, Seq(keys: _*), "inner")
        .withColumn(columns.head + "Out", getChange(df1.col(columns.head), df2.col(columns.head), df3.col(columns.head)))
        .select(col(columns.head + "Out").as(columns.head) +: keys.map(x => df1.col(x)): _*)

      helper(columns.tail)
    }
  }

  helper(columns)
}

val list: scala.collection.mutable.ListBuffer[DataFrame] = createFrame(cols, df1, df2, df3)
list.foreach(a => a.show(false))

val keys = cols.slice(0, 4)

list.reduce((a, b) =>
    a.alias("a").join(b.alias("b"), Seq(keys: _*), "inner")
      .select("a.*", "b." + b.columns.head))
  .orderBy(cols.head)
  .show(false)

This produces:

+---+---+---+---+----+----+----+
|pk1|pk2|pk3|pk4|val1|val2|val3|
+---+---+---+---+----+----+----+
|1  |aa |c  |d  |ac  |ad  |ae  |
|2  |bb |d  |e  |bb  |bd  |bg  |
+---+---+---+---+----+----+----+

Answer 1 (score: 0)

I was also able to do this by registering the DataFrames as temporary views and then running a SELECT with a CASE statement, like this:

df1.createTempView("df1")
df2.createTempView("df2")
df3.createTempView("df3")

select case when df1.val1=df2.val1 and df1.val1<>df3.val1 then df3.val1 end
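The CASE expression above is only a fragment; one possible way to flesh it out (a sketch only — the joins and the ELSE branch are assumptions extrapolated from the expected output, and it presumes a Spark 2.x SparkSession named spark) is:

// Sketch only: join the three temp views on the keys and fall back to df2's
// value when the first two files already differ.
val merged = spark.sql("""
  SELECT df1.pk1,
         df1.pk2,
         CASE WHEN df1.val1 = df2.val1 AND df1.val1 <> df3.val1 THEN df3.val1
              ELSE df2.val1 END AS val1,
         CASE WHEN df1.val2 = df2.val2 AND df1.val2 <> df3.val2 THEN df3.val2
              ELSE df2.val2 END AS val2
  FROM df1
  JOIN df2 ON df1.pk1 = df2.pk1 AND df1.pk2 = df2.pk2
  JOIN df3 ON df1.pk1 = df3.pk1 AND df1.pk2 = df3.pk2
""")

merged.show()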

This was much faster.