Retrieve top n in each group of a DataFrame in pyspark

Time: 2016-07-15 13:49:34

Tags: python apache-spark dataframe pyspark apache-spark-sql

I have a DataFrame in pyspark with data like the following:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

For each group of rows with the same user_id, I want to return the 2 records with the highest score. So the result should look like this:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

I'm new to pyspark. Could anyone give me a code snippet or point me to the relevant documentation to solve this problem? Thanks a lot!

6 Answers:

Answer 0: (score: 46)

I think you need to use window functions to compute the rank of each row based on user_id and score, and then filter your results to keep only the top two values.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

In general, the official programming guide is a good place to start learning Spark.

Data

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

Answer 1: (score: 18)

Top-n is more accurate if you use row_number instead of rank when there are ties: rank gives tied rows the same rank, so a group can return more than n rows, while row_number returns exactly n:

from pyspark.sql.functions import col, row_number

# Reuses the window defined in the accepted answer
n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

For Jupyter notebooks, note the limit(20).toPandas() trick instead of show(), which gives nicer formatting.

Answer 2: (score: 2)

I know this question was asked for pyspark, while I was looking for a similar answer in Scala, i.e.

Retrieve top n values in each group of a DataFrame in Scala

Here is the Scala version of @mtoto's answer.

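A minimal sketch of the rank-over-window approach in Scala, assuming the same df with user_id and score columns as in the question:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, rank}

// Rank rows within each user_id, highest score first
val window = Window.partitionBy("user_id").orderBy(col("score").desc)

// Keep the top two rows per user, then drop the helper column
val top2 = df.withColumn("rank", rank().over(window))
  .where(col("rank") <= 2)
  .drop("rank")

top2.show()

As in the pyspark version, swapping rank for row_number would return exactly two rows per user even when scores tie.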

More examples can be found here.

Answer 3: (score: 2)

Here is another approach, without window functions, to get the top N records from a PySpark DataFrame.

# Import Libraries
from pyspark.sql.functions import col

# Sample Data
rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)

# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)

# Display DataFrame
sorted_df.show()

Output

+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2|    2|
| user_2| object_2|    2|
| user_1| object_1|    3|
| user_2| object_1|    5|
| user_2| object_2|    6|
+-------+---------+-----+

If you are interested in more window functions in Spark, you can refer to one of my blog posts: https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86

Answer 4: (score: 1)

Using Python 3 and Spark 2.4:

from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    # Number the rows within each group, ordered by the given column in descending order
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
    # Keep only the first n rows per group and drop the helper column
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

# For the question's data: keep the single highest-scoring row per user_id
top_n_df = get_topN(df, ['user_id'], f.col('score'), n=1)

Answer 5: (score: 0)

Use the ROW_NUMBER() function to find the Nth highest value in a PySpark SQL query:

SELECT * FROM (
    SELECT e.*, 
    ROW_NUMBER() OVER (ORDER BY col_name DESC) rn 
    FROM Employee e
)
WHERE rn = N

where N is the nth highest value requested from the column.

Output:

+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

The query returns the Nth highest value of col_name.