来自Spark DataFrame / RDD的前N项

时间:2018-02-13 20:28:13

标签: scala apache-spark top-n

我的要求是从数据框中获取前N个项目。

我有这个DataFrame:

val df = List(
      ("MA", "USA"),
      ("MA", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA")).toDF("value", "country")

我能够将其映射到RDD[((Int, String), Long)] colValCount: 阅读:((colIdx,value),count)

((0,CT),5)
((0,MA),2)
((0,OH),4)
((0,NY),6)
((1,USA),17)

现在我需要为每个列索引获取前2项。所以我的预期输出是:

RDD[((Int, String), Long)]

((0,CT),5)
((0,NY),6)
((1,USA),17)

我尝试在DataFrame中使用freqItems api,但速度很慢。

欢迎任何建议。

5 个答案:

答案 0 :(得分:3)

例如:

import org.apache.spark.sql.functions._

df.select(lit(0).alias("index"), $"value")
   .union(df.select(lit(1), $"country"))
   .groupBy($"index", $"value")
   .count
  .orderBy($"count".desc)
  .limit(3)
  .show

// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    1|  USA|   17|
// |    0|   NY|    6|
// |    0|   CT|    5|
// +-----+-----+-----+

其中:

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))

创建了两列DataFrame

// +-----+-----+
// |index|value|
// +-----+-----+
// |    0|   MA|
// |    0|   MA|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    1|  USA|
// |    1|  USA|
// |    1|  USA|
// +-----+-----+

如果您想为每列特别指定两个值:

import org.apache.spark.sql.DataFrame

def topN(df: DataFrame, key: String, n: Int)  = {
   df.select(
        lit(df.columns.indexOf(key)).alias("index"), 
        col(key).alias("value"))
     .groupBy("index", "value")
     .count
     .orderBy($"count")
     .limit(n)
}

topN(df, "value", 2).union(topN(df, "country", 2)).show
// +-----+-----+-----+ 
// |index|value|count|
// +-----+-----+-----+
// |    0|   MA|    2|
// |    0|   OH|    4|
// |    1|  USA|   17|
// +-----+-----+-----+

pault said一样 - 只是“ sort()limit() 的某种组合。”

答案 1 :(得分:3)

最简单的方法 - 自然窗口函数 - 是通过编写SQL。 Spark带有SQL语法,而SQL是解决这个问题的绝佳工具。

将您的数据框注册为临时表,然后在其上进行分组和窗口。

spark.sql("""SELECT idx, value, ROW_NUMBER() OVER (PARTITION BY idx ORDER BY c DESC) as r 
             FROM (
               SELECT idx, value, COUNT(*) as c 
               FROM (SELECT 0 as idx, value FROM df UNION ALL SELECT 1, country FROM df) 
               GROUP BY idx, value) 
             HAVING r <= 2""").show()

我想看看是否有任何过程/ scala方法可以让你在没有迭代或循环的情况下执行窗口函数。我不知道Spark API中会支持它的任何内容。

顺便提一下,如果您想要包含任意数量的列,那么您可以使用列表动态生成内部部分(SELECT 0 as idx, value ... UNION ALL SELECT 1, country等)。

答案 2 :(得分:1)

鉴于你上次的RDD:

val rdd =
  sc.parallelize(
    List(
      ((0, "CT"), 5),
      ((0, "MA"), 2),
      ((0, "OH"), 4),
      ((0, "NY"), 6),
      ((1, "USA"), 17)
    ))

rdd.filter(_._1._1 == 0).sortBy(-_._2).take(2).foreach(println)
> ((0,NY),6)
> ((0,CT),5)
rdd.filter(_._1._1 == 1).sortBy(-_._2).take(2).foreach(println)
> ((1,USA),17)

我们首先获取给定列索引(.filter(_._1._1 == 0))的项目。然后我们按降序排序(.sortBy(-_._2))。最后,我们最多采用2个第一个元素(.take(2)),如果记录的nbr小于2,则只需要1个元素。

答案 3 :(得分:0)

您可以使用Sparkz中定义的辅助函数映射每个单独的分区,然后将它们组合在一起:

package sparkz.utils

import scala.reflect.ClassTag

object TopElements {
  def topN[T: ClassTag](elems: Iterable[T])(scoreFunc: T => Double, n: Int): List[T] =
    elems.foldLeft((Set.empty[(T, Double)], Double.MaxValue)) {
      case (accumulator@(topElems, minScore), elem) =>
        val score = scoreFunc(elem)
        if (topElems.size < n)
          (topElems + (elem -> score), math.min(minScore, score))
        else if (score > minScore) {
          val newTopElems = topElems - topElems.minBy(_._2) + (elem -> score)
          (newTopElems, newTopElems.map(_._2).min)
        }
        else accumulator
    }
      ._1.toList.sortBy(_._2).reverse.map(_._1)
}

来源:https://github.com/gm-spacagna/sparkz/blob/master/src/main/scala/sparkz/utils/TopN.scala

答案 4 :(得分:0)

如果您正在使用Spark SQL Dataframes,在我看来,最好的(也是更容易理解的解决方案)是像这样执行您的代码:

val test: Dataframe = df.select(col("col_name"))
test.show(5, false)

希望它可以帮助你:)