在Spark数据帧中为每个组创建索引

时间:2017-03-03 20:39:55

标签: apache-spark apache-spark-sql

我在Spark中有一个包含2列group_idvalue的数据框,其中value是双精度数。我想根据group_id对数据进行分组,按value对每个组进行排序,然后添加第三列index,表示value在排序中的位置该组的价值观。

例如,考虑以下输入数据:

+--------+-----+
|group_id|value|
+--------+-----+
|1       |1.3  |
|2       |0.8  |
|1       |3.4  |
|1       |-1.7 |
|2       |2.3  |
|2       |5.9  |
|1       |2.7  |
|1       |0.0  |
+--------+-----+

输出将类似于

+--------+-----+-----+
|group_id|value|index|
+--------+-----+-----+
|1       |-1.7 |1    |
|1       |0.0  |2    |
|1       |1.3  |3    |
|1       |2.7  |4    |
|1       |3.4  |5    |
|2       |0.8  |1    |
|2       |2.3  |2    |
|2       |5.9  |3    |
+--------+-----+-----+

如果索引从0开始并且排序是升序还是降序,则不重要。

作为后续行动,请考虑原始数据中存在第三列extra的情况,该列会为某些(group_id, value)组合带来多个值。一个例子是:

+--------+-----+-----+
|group_id|value|extra|
+--------+-----+-----+
|1       |1.3  |1    |
|1       |1.3  |2    |
|2       |0.8  |1    |
|1       |3.4  |1    |
|1       |3.4  |2    |
|1       |3.4  |3    |
|1       |-1.7 |1    |
|2       |2.3  |1    |
|2       |5.9  |1    |
|1       |2.7  |1    |
|1       |0.0  |1    |
+--------+-----+-----+

是否可以添加index列,以便extra列不被考虑但仍保留?这种情况下的输出是

+--------+-----+-----+-----+
|group_id|value|extra|index|
+--------+-----+-----+-----+
|1       |-1.7 |1    |1    |
|1       |0.0  |1    |2    |
|1       |1.3  |1    |3    |
|1       |1.3  |2    |3    |
|1       |2.7  |1    |4    |
|1       |3.4  |1    |5    |
|1       |3.4  |2    |5    |
|1       |3.4  |3    |5    |
|2       |0.8  |1    |1    |
|2       |2.3  |1    |2    |
|2       |5.9  |1    |3    |
+--------+-----+-----+-----+

我知道可以通过复制数据,删除extra列来实现此目的

  1. 复制数据
  2. 删除extra
  3. 执行distinct操作,这将导致原始示例中的数据
  4. 使用原始解决方案计算index
  5. 使用第二个示例
  6. 中的数据加入结果

    然而,这将涉及大量额外的计算和开销。

1 个答案:

答案 0 :(得分:7)

您可以使用Window函数创建基于value的排名列,按group_id分区:

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, dense_rank
# Define window
window = Window.partitionBy(df['group_id']).orderBy(df['value'])
# Create column
df.select('*', rank().over(window).alias('index')).show()
+--------+-----+-----+
|group_id|value|index|
+--------+-----+-----+
|       1| -1.7|    1|
|       1|  0.0|    2|
|       1|  1.3|    3|
|       1|  2.7|    4|
|       1|  3.4|    5|
|       2|  0.8|    1|
|       2|  2.3|    2|
|       2|  5.9|    3|
+--------+-----+-----+

因为,您首先选择'*',所以您也可以使用上面的代码保留所有其他变量。但是,您的第二个示例显示您正在查找函数dense_rank(),该函数作为排名列提供无间隙:

df.select('*', dense_rank().over(window).alias('index'))