如何在spark sql 2.1.0中的数据集<row>上获取groupby之后的所有列

时间:2017-01-05 07:05:25

标签: apache-spark apache-spark-sql

首先,我是SPARK的新手

我的数据集中有数百万条记录,我希望将groupby与name列分组,并查找具有最大年龄的名称。我得到了正确的结果,但我需要结果集中的所有列。

Dataset<Row> resultset = studentDataSet.select("*").groupBy("name").max("age");
resultset.show(1000,false);

我的结果集数据集中只有name和max(age)。

5 个答案:

答案 0 :(得分:12)

对于您的解决方案,您必须尝试不同的方法。你几乎在那里寻求解决方案,但让我帮助你理解。

Dataset<Row> resultset = studentDataSet.groupBy("name").max("age");

现在你可以做的是你可以加入resultsetstudentDataSet

Dataset<Row> joinedDS = studentDataset.join(resultset, "name");

groupBy这个问题在应用groupBy之后得到RelationalGroupedDataset所以它取决于你执行的下一个操作,如sum, min, mean, max等,然后这些操作的结果加入{{1 }}

在您的情况下,groupBy列与name的{​​{1}}相关联,因此它只返回两列但如果在max上使用age然后在“年龄”列上应用groupBy,您将获得两个第一列age,第二列是max

注意: - 代码未经过测试,请根据需要进行更改 希望这能让你清楚查询

答案 1 :(得分:2)

接受的答案并不理想,因为它需要加入。加入大型DataFrame可能会导致大型洗牌,执行速度会很慢。

让我们创建一个示例数据集并测试代码:

val df = Seq(
  ("bob", 20, "blah"),
  ("bob", 40, "blah"),
  ("karen", 21, "hi"),
  ("monica", 43, "candy"),
  ("monica", 99, "water")
).toDF("name", "age", "another_column")

此代码在大型DataFrame上应该运行得更快。

df
  .groupBy("name")
  .agg(
    max("name").as("name1_dup"), 
    max("another_column").as("another_column"),  
    max("age").as("age")
  ).drop(
    "name1_dup"
  ).show()

+------+--------------+---+
|  name|another_column|age|
+------+--------------+---+
|monica|         water| 99|
| karen|            hi| 21|
|   bob|          blah| 40|
+------+--------------+---+

答案 2 :(得分:0)

您需要记住,聚合函数会减少行的数量,因此您需要指定使用减少功能的行的年龄。如果要保留组中的所有行(警告!这可能会导致爆炸或分区倾斜),则可以将其作为列表收集。然后,您可以使用UDF(用户定义的函数)根据您的条件减少它们,在本示例中为funniness_of_requisite。然后,使用另一个UDF从单个精简行中扩展属于精简行的列。 出于这个答案的目的,我假设您希望保留具有最大funniness_of_condition的人的年龄。

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{IntegerType, StringType}

import scala.collection.mutable


object TestJob4 {

def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._

val rawDf = Seq(
  (1, "Moe",  "Slap",  7.9, 118),
  (2, "Larry",  "Spank",  8.0, 115),
  (3, "Curly",  "Twist", 6.0, 113),
  (4, "Laurel", "Whimper", 7.53, 119),
  (5, "Hardy", "Laugh", 6.0, 18),
  (6, "Charley",  "Ignore",   9.7, 115),
  (2, "Moe",  "Spank",  6.8, 118),
  (3, "Larry",  "Twist", 6.0, 115),
  (3, "Charley",  "fall", 9.0, 115)
).toDF("id", "name", "requisite", "funniness_of_requisite", "age")

rawDf.show(false)
rawDf.printSchema

val rawSchema = rawDf.schema

val fUdf = udf(reduceByFunniness, rawSchema)

val nameUdf = udf(extractAge, IntegerType)

val aggDf = rawDf
  .groupBy("name")
  .agg(
    count(struct("*")).as("count"),
    max(col("funniness_of_requisite")),
    collect_list(struct("*")).as("horizontal")
  )
  .withColumn("short", fUdf($"horizontal"))
  .withColumn("age", nameUdf($"short"))
  .drop("horizontal")

aggDf.printSchema

aggDf.show(false)
}

def reduceByFunniness= (x: Any) => {

val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

val red = d.reduce((r1, r2) => {

  val funniness1 = r1.getAs[Double]("funniness_of_requisite")
  val funniness2 = r2.getAs[Double]("funniness_of_requisite")

  val r3 = funniness1 match {
    case a if a >= funniness2 =>
      r1
    case _ =>
      r2
  }

  r3
})

red
}

def extractAge = (x: Any) => {

val d = x.asInstanceOf[GenericRowWithSchema]

d.getAs[Int]("age")
}
 }

  d.getAs[String]("name")
}
}

这是输出

+-------+-----+---------------------------+-------------------------------+---+
|name   |count|max(funniness_of_requisite)|short                          
|age|
+-------+-----+---------------------------+-------------------------------+---+
|Hardy  |1    |6.0                        |[5, Hardy, Laugh, 6.0, 18]     
|18 |
|Moe    |2    |7.9                        |[1, Moe, Slap, 7.9, 118]       
|118|
|Curly  |1    |6.0                        |[3, Curly, Twist, 6.0, 113]    
|113|
|Larry  |2    |8.0                        |[2, Larry, Spank, 8.0, 115]    
|115|
|Laurel |1    |7.53                       |[4, Laurel, Whimper, 7.53, 119]|119|
|Charley|2    |9.7                        |[6, Charley, Ignore, 9.7, 115] |115|
+-------+-----+---------------------------+-------------------------------+---+

答案 3 :(得分:0)

您要实现的目标是

  1. 按年龄分组行
  2. 将每个组的年龄减至1行

此替代方法无需使用聚合即可实现此输出

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._


object TestJob5 {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("Moe",  "Slap",  7.9, 118),
      ("Larry",  "Spank",  8.0, 115),
      ("Curly",  "Twist", 6.0, 113),
      ("Laurel", "Whimper", 7.53, 119),
      ("Hardy", "Laugh", 6.0, 118),
      ("Charley",  "Ignore",   9.7, 115),
      ("Moe",  "Spank",  6.8, 118),
      ("Larry",  "Twist", 6.0, 115),
      ("Charley",  "fall", 9.0, 115)
    ).toDF("name", "requisite", "funniness_of_requisite", "age")

    rawDf.show(false)
    rawDf.printSchema

    val nameWindow = Window
      .partitionBy("name")

    val aggDf = rawDf
      .withColumn("id", monotonically_increasing_id)
      .withColumn("maxFun", max("funniness_of_requisite").over(nameWindow))
      .withColumn("count", count("name").over(nameWindow))
      .withColumn("minId", min("id").over(nameWindow))
      .where(col("maxFun") === col("funniness_of_requisite") && col("minId") === col("id") )
      .drop("maxFun")
      .drop("minId")
      .drop("id")

    aggDf.printSchema

    aggDf.show(false)
  }

}

请记住,一个组的最大年龄可能会超过1行,因此您需要按逻辑选择一个。在示例中,我认为这没关系,因此我只分配一个唯一的数字来选择

答案 4 :(得分:0)

注意到随后的联接是多余的改组,并且其他一些解决方案的返回结果似乎不准确,甚至将数据集转换为数据帧,我寻求了更好的解决方案。这是我的:

case class People(name: String, age: Int, other: String)   
val df = Seq(
  People("Rob", 20, "cherry"),
  People("Rob", 55, "banana"),
  People("Rob", 40, "apple"),
  People("Ariel", 55, "fox"),
  People("Vera", 43, "zebra"),
  People("Vera", 99, "horse")
).toDS

val oldestResults = df
 .groupByKey(_.name)
 .mapGroups{ 
    case (nameKey, peopleIter) => {
        var oldestPerson = peopleIter.next  
        while(peopleIter.hasNext) {
            val nextPerson = peopleIter.next
            if(nextPerson.age > oldestPerson.age) oldestPerson = nextPerson 
        }
        oldestPerson
    }
  }    
  oldestResults.show  

以下内容产生:

+-----+---+------+
| name|age| other|
+-----+---+------+
|Ariel| 55|   fox|
|  Rob| 55|banana|
| Vera| 99| horse|
+-----+---+------+