pyspark相当于pandas groupby('col1')。col2.head()

时间:2018-05-09 11:23:54

标签: group-by pyspark spark-dataframe sample

我有一个Spark Dataframe,对于每组具有给定列值(col1)的行,我想获取(col2)中值的样本。 col1的每个可能值的行数可能差异很大,所以我只是寻找每种类型的一组数字,比如说10。

可能有更好的方法可以做到这一点,但自然的方法似乎是df.groupby('col1')

在pandas中,我可以做df.groupby('col1')。col2.head()

我知道spark数据帧不是pandas数据帧,但这是一个很好的类比。

我想我可以将所有col1类型作为过滤器循环,但这看起来非常糟糕。

有关如何做到这一点的任何想法?谢谢。

1 个答案:

答案 0 :(得分:1)

让我创建一个包含两列的示例Spark数据帧。

df = SparkSQLContext.createDataFrame([[1, 'r1'],
 [1, 'r2'],
 [1, 'r2'],
 [2, 'r1'],
 [3, 'r1'],
 [3, 'r2'],
 [4, 'r1'],
 [5, 'r1'],
 [5, 'r2'],
 [5, 'r1']], schema=['col1', 'col2'])
df.show()

+----+----+
|col1|col2|
+----+----+
|   1|  r1|
|   1|  r2|
|   1|  r2|
|   2|  r1|
|   3|  r1|
|   3|  r2|
|   4|  r1|
|   5|  r1|
|   5|  r2|
|   5|  r1|
+----+----+

按col1分组后,我们得到GroupedData对象(而不是Spark Dataframe)。您可以使用聚合函数,如min,max,average。但是获得一个头()有点棘手。我们需要将GroupedData对象转换回Spark Dataframe。这可以使用pyspark collect_list()聚合函数来完成。

from pyspark.sql import functions
df1 = df.groupBy(['col1']).agg(functions.collect_list("col2")).show(n=3)

输出是:

+----+------------------+
|col1|collect_list(col2)|
+----+------------------+
|   5|      [r1, r2, r1]|
|   1|      [r1, r2, r2]|
|   3|          [r1, r2]|
+----+------------------+
only showing top 3 rows