Spark sessionization with DataFrames

Date: 2019-02-04 09:06:06

Tags: scala apache-spark apache-spark-sql user-defined-functions

I want to perform clickstream sessionization on a Spark DataFrame. Let's load a DataFrame containing events from multiple sessions, with the following schema: [image: input DataFrame schema]

I want to aggregate (stitch) the sessions like this: [image: desired stitched session output]

I have looked at UDAFs and Window functions, but could not figure out how to use them for this particular use case. I know that partitioning the data by session ID puts all of a session's data in a single partition, but how do I aggregate it?

The idea is to aggregate all the events specific to each session into a single output record.

1 Answer:

Answer 0 (score: 2):

You can use collect_set:

 def process(implicit spark: SparkSession) = {
   import org.apache.spark.sql.Row
   import org.apache.spark.sql.functions.collect_set
   import org.apache.spark.sql.types.{ IntegerType, StringType, StructField, StructType }

   // Two events belonging to the same session (sessionId = 1, userId = 1)
   val seq = Seq(Row(1, 1, "startTime=1549270909"), Row(1, 1, "endTime=1549270913"))

   val rdd = spark.sparkContext.parallelize(seq)

   val df1 = spark.createDataFrame(rdd, StructType(List(
     StructField("sessionId", IntegerType, false),
     StructField("userId", IntegerType, false),
     StructField("session", StringType, false))))

   // Collapse all events of a session into a single set
   df1.groupBy("sessionId").agg(collect_set("session"))
 }

That gives you:

+---------+------------------------------------------+
|sessionId|collect_set(session)                      |
+---------+------------------------------------------+
|1        |[startTime=1549270909, endTime=1549270913]|
+---------+------------------------------------------+

as the output.
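For reference, here is a minimal sketch (not part of the original answer; the object name and Spark settings are illustrative) of how the process method above could be driven from a standalone program:

import org.apache.spark.sql.SparkSession

object SessionizationDemo {
  def main(args: Array[String]): Unit = {
    // Local SparkSession just for trying out the snippet; master/appName are arbitrary
    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("sessionization-demo")
      .getOrCreate()

    // Assumes the process method defined above is in scope;
    // show(false) prints the collected session set without truncation
    process.show(false)

    spark.stop()
  }
}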

If you need more complex logic, you can wrap it in a UDAF like the following:

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.expressions.{ MutableAggregationBuffer, UserDefinedAggregateFunction }
  import org.apache.spark.sql.types.{ DataType, StringType, StructField, StructType }

  class YourComplexLogicStrings extends UserDefinedAggregateFunction {
    // The aggregate consumes a single string column
    override def inputSchema: StructType = StructType(StructField("input", StringType) :: Nil)

    // Intermediate state: the partially stitched string
    override def bufferSchema: StructType = StructType(StructField("pair", StringType) :: Nil)

    // The aggregate produces a single string
    override def dataType: DataType = StringType

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = ""

    // Fold each incoming row into the buffer
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val b = buffer.getAs[String](0)
      val i = input.getAs[String](0)
      buffer(0) = if (b.isEmpty) i else b + " + " + i
    }

    // Combine partial results computed on different partitions
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val b1 = buffer1.getAs[String](0)
      val b2 = buffer2.getAs[String](0)
      buffer1(0) = if (b1.isEmpty) b2 else b1 + "," + b2
    }

    // Post-process the final buffer value
    override def evaluate(buffer: Row): Any = {
      val yourString = buffer.getAs[String](0)
      // Compute your logic and return another String
      yourString
    }
  }



def process0(implicit spark: SparkSession) = {

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.functions.col
  import org.apache.spark.sql.types.{ IntegerType, StringType, StructField, StructType }

  val agg0 = new YourComplexLogicStrings()

  val seq = Seq(Row(1, 1, "startTime=1549270909"), Row(1, 1, "endTime=1549270913"))

  val rdd = spark.sparkContext.parallelize(seq)

  val df1 = spark.createDataFrame(rdd, StructType(List(
    StructField("sessionId", IntegerType, false),
    StructField("userId", IntegerType, false),
    StructField("session", StringType, false))))

  // Apply the custom UDAF instead of the built-in collect_set
  df1.groupBy("sessionId").agg(agg0(col("session")))
}

It gives:

+---------+---------------------------------------+
|sessionId|yourcomplexlogicstrings(session)       |
+---------+---------------------------------------+
|1        |startTime=1549270909,endTime=1549270913|
+---------+---------------------------------------+

Note that you can express quite complex logic directly with Spark SQL functions if you want to avoid a UDAF.
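As a sketch of that idea (not from the original answer; the column names follow the example above), the same stitching can be done with built-in functions only, for example collect_list plus concat_ws. sort_array is used here just to make the result deterministic, since collect_list does not guarantee ordering:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{ col, collect_list, concat_ws, sort_array }

def stitchWithBuiltins(df: DataFrame): DataFrame =
  // Collect all session strings per sessionId, sort them, and join with commas
  df.groupBy("sessionId")
    .agg(concat_ws(",", sort_array(collect_list(col("session")))).as("stitchedSession"))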