如何使用mapGroups在scala spark中的groupby之后计算列中的不同值

时间:2018-10-02 18:31:52

标签: scala apache-spark

我是scala spark的新手。我有一个文本文件数据

001,delhi,india
002,chennai,india
003,hyderabad,india
004,newyork,us
005,chicago,us
006,lasvegas,us
007,seattle,us

我想计算每个国家/地区中不同城市的数量,因此我应用了groupBy和mapGroups。我不确定如何计算mapGroups中的值。请在下面找到我的代码

val ds1 = sparkSession.read.textFile("samplefile.txt").map(x => x.split(","))
  val ds2 = ds1.groupByKey(x => x(2)).mapGroups{case(k,iter) => (k,iter.map(x => x(2)).toArray)}

请帮助我提供语法。我知道可以通过spark-sql轻松完成,但是我想通过scala完成

2 个答案:

答案 0 :(得分:0)

正确的方法是将df用作源数据帧,

import org.apache.spark.sql.functions._

val df: DataFrame = ???

val result = df.groupBy("country col name").agg(countDistinct("city column name").alias("city_count"))

希望这会有所帮助。

答案 1 :(得分:0)

要计算每个国家/地区的不同城市,可以将按国家/地区列表映射到city数组,并计算不同城市的数量:

val ds1 = spark.read.textFile("/path/to/textfile").map(_.split(","))
val ds2 = ds1.
  groupByKey(_(2)).mapGroups{ case (k, iter) =>
    (k, iter.map(_(1)).toList.distinct.size)
  }

[更新]

要计算每个国家/地区的平均值,例如,从第4个数字列中进行以下操作:

val ds3 = ds1.
  groupByKey(_(2)).mapGroups{ case (k, iter) =>
    val numList = iter.map(_(3).toDouble).toList
    (k, numList.sum / numList.size)
  }

如果您需要各种数字聚合,我认为Spark DataFrame API将是一种更有效的工具(例如,它内置了avg())。