Applying a UDF while interval-joining two Spark dataframes

Date: 2018-08-27 01:52:17

Tags: scala apache-spark dataframe user-defined-functions

I have a dataframe with three columns: id, index and value:

+---+-----+-------------------+
| id|index|              value|
+---+-----+-------------------+
|  A| 1023|0.09938822262205915|
|  A| 1046| 0.3110047630613805|
|  A| 1069| 0.8486710971453512|
+---+-----+-------------------+

root
 |-- id: string (nullable = true)
 |-- index: integer (nullable = false)
 |-- value: double (nullable = false)

Then I have another dataframe that gives the desired periods for each id:

+---+-----------+---------+
| id|start_index|end_index|
+---+-----------+---------+
|  A|       1069|     1276|
|  B|       2066|     2291|
|  B|       1616|     1841|
|  C|       3716|     3932|
+---+-----------+---------+

root
 |-- id: string (nullable = true)
 |-- start_index: integer (nullable = false)
 |-- end_index: integer (nullable = false)

And I have the following three templates:

val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)

The goal is, for each row of dfIntervals, to apply a function (let's assume it is correlation) that receives the value column of dfRaw and the three template arrays, and adds three columns to dfIntervals, one per template.
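Schematically, the result keeps one row per interval and gains one correlation column per template (values omitted here; the column names follow the description further down):

+---+-----------+---------+----------------+----------------+----------------+
| id|start_index|end_index|corr_w_template1|corr_w_template2|corr_w_template3|
+---+-----------+---------+----------------+----------------+----------------+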

Assumptions:

1 - The template arrays have exactly 10 elements each.

2 - There are no duplicates in the index column of dfRaw.

3 - The start_index and end_index values in dfIntervals exist in the index column of dfRaw, and there are exactly 10 rows between them. For instance, dfRaw.filter($"id" === "A").filter($"index" >= 1069 && $"index" <= 1276).count (the first row of dfIntervals) yields exactly 10.

Here is the code that generates these dataframes:

import org.apache.spark.sql.functions._

val mySeed = 1000

/* Defining templates for correlation analysis */
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)

/* Defining raw data */
var dfRaw = Seq(
  ("A", (1023 to 1603 by 23).toArray),
  ("B", (341 to 2300 by 25).toArray),
  ("C", (2756 to 3954 by 24).toArray)
).toDF("id", "index")
dfRaw = dfRaw.select($"id", explode($"index") as "index").withColumn("value", rand(seed = mySeed))

/* Defining intervals */
var dfIntervals = Seq(
  ("A", 1069, 1276),
  ("B", 2066, 2291),
  ("B", 1616, 1841),
  ("C", 3716, 3932)
).toDF("id", "start_index", "end_index")

The result is three new columns on dfIntervals, named corr_w_template1, corr_w_template2 and corr_w_template3.

PS: I could not find a correlation function in Scala. Let's assume such a function exists (something like the sketch below) and we will wrap it in a UDF as needed.
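For instance, any implementation with the following shape would do (the name and signature here are just an illustration, with the body left open):

// Hypothetical signature; any pairwise score over two equal-length arrays fits.
def correlation(values: Array[Double], template: Array[Double]): Double = ???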

1 Answer:

Answer 0 (score: 1)

OK, let's define the UDFs first. For testing purposes, the correlation UDF will simply always return 1:

import org.apache.spark.sql.{Row, functions}
import org.apache.spark.sql.functions.udf
import scala.collection.mutable

// Stub correlation UDF: always returns 1 for now.
val correlation = functions.udf((values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
  1f
})

// Sorts the collected (index, value) structs by index and keeps only the values.
val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
  values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
})
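If a real measure is needed later, the stub can be swapped for a plain Pearson correlation. A minimal sketch, assuming both arrays are non-empty, of equal length, and not constant:

// Minimal Pearson correlation UDF (assumption: equal-length, non-constant inputs).
val pearson = functions.udf((xs: Seq[Double], ys: Seq[Double]) => {
  val n = xs.length
  val meanX = xs.sum / n
  val meanY = ys.sum / n
  // Center both series once, then combine into covariance and variances.
  val cov  = xs.zip(ys).map { case (x, y) => (x - meanX) * (y - meanY) }.sum
  val varX = xs.map(x => (x - meanX) * (x - meanX)).sum
  val varY = ys.map(y => (y - meanY) * (y - meanY)).sum
  cov / math.sqrt(varX * varY)
})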

Then join your two dataframes on id with the interval rule, group by interval, and collect the (index, value) pairs into a single column called values, applying our orderUdf so the values come back sorted by index:

 val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index")  <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
    .groupBy(dfIntervals("id"), dfIntervals("start_index"),  dfIntervals("end_index"))
    .agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))
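At this point df holds one row per interval, with that interval's values ordered by index. Per assumption 3, each array should contain exactly 10 elements, which is easy to sanity-check:

// Hedged check: every interval is expected to yield exactly 10 ordered values.
df.select(col("id"), size(col("values")).as("n_values")).show()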

Finally, apply the correlation UDF once per template and show the result:

df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
    .withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
    .withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
    .show(10)

Here is the full example code:

  import org.apache.spark.SparkConf
  import org.apache.spark.sql.{Row, SparkSession, functions}
  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.types.DataTypes
  import scala.collection.JavaConverters._
  import scala.collection.mutable

  val conf = new SparkConf().setAppName("learning").setMaster("local[2]")
  val session = SparkSession.builder().config(conf).getOrCreate()

  val mySeed = 1000

  /* Defining templates for correlation analysis*/
  val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
  val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
  val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)

  val schema1 =  DataTypes.createStructType(Array(
    DataTypes.createStructField("id",DataTypes.StringType,false),
    DataTypes.createStructField("index",DataTypes.createArrayType(DataTypes.IntegerType),false)
  ))

  val schema2 =  DataTypes.createStructType(Array(
    DataTypes.createStructField("id",DataTypes.StringType,false),
    DataTypes.createStructField("start_index",DataTypes.IntegerType,false),
    DataTypes.createStructField("end_index",DataTypes.IntegerType,false)
  ))

  /* Defining raw data*/
  var dfRaw = session.createDataFrame(Seq(
    ("A", (1023 to 1603 by 23).toArray),
    ("B", (341 to 2300 by 25).toArray),
    ("C", (2756 to 3954 by 24).toArray)
  ).map(r => Row(r._1 , r._2)).asJava, schema1)

  dfRaw = dfRaw.select(dfRaw("id"), explode(dfRaw("index")) as "index")
    .withColumn("value", rand(seed=mySeed))

  /* Defining intervals*/
  var dfIntervals =  session.createDataFrame(Seq(
    ("A", 1069, 1276),
    ("B", 2066, 2291),
    ("B", 1616, 1841),
    ("C", 3716, 3932)
  ).map(r => Row(r._1 , r._2,r._3)).asJava, schema2)

  // Define UDFs: the stub correlation and the index-ordering helper

  val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
     1f
  })

  val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
    values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
  })


  val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index")  <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
    .groupBy(dfIntervals("id"), dfIntervals("start_index"),  dfIntervals("end_index"))
    .agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))

  df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
    .withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
    .withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
    .show(10,false)
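
As a side note, the ordering UDF can be avoided with built-ins: sort_array on an array of structs orders by the first struct field (index here), and the value field can then be extracted from the sorted structs. A sketch under the same imports as above:

  val dfNoUdf = dfIntervals.join(dfRaw, dfIntervals("id") === dfRaw("id")
      && dfIntervals("start_index") <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index"))
    .groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
    .agg(sort_array(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("pairs"))
    // Field extraction on an array of structs yields an array of that field.
    .withColumn("values", col("pairs.value"))
    .drop("pairs")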