Apache Spark:数据帧中行值列表的百分位数

时间:2017-10-03 00:15:29

标签: list dataframe pyspark

我有一个带有一组计算列的Apache Spark数据帧。对于数据框中的每一行(大约2000),我希望获取10列的行值,并找到第11列相对于其他10的最接近的值。

我想我会取这些行值并将其转换为列表然后使用abs值计算来确定最接近的值。

但我仍然坚持如何将行值转换为列表。我已经采用了一个列并使用collect_list将这些值转换为列表但不确定当列表来自单行和多列时如何处理。

1 个答案:

答案 0 :(得分:1)

您应该explode列,以便线性化您的计算。

让我们创建一个示例数据框:

import numpy as np
np.random.seed(0)
df = sc.parallelize([np.random.randint(0, 10, 11).tolist() for _ in range(20)])\
    .toDF(["col" + str(i) for i in range(1, 12)])
df.show()
    +----+----+----+----+----+----+----+----+----+-----+-----+
    |col1|col2|col3|col4|col5|col6|col7|col8|col9|col10|col11|
    +----+----+----+----+----+----+----+----+----+-----+-----+
    |   5|   0|   3|   3|   7|   9|   3|   5|   2|    4|    7|
    |   6|   8|   8|   1|   6|   7|   7|   8|   1|    5|    9|
    |   8|   9|   4|   3|   0|   3|   5|   0|   2|    3|    8|
    |   1|   3|   3|   3|   7|   0|   1|   9|   9|    0|    4|
    |   7|   3|   2|   7|   2|   0|   0|   4|   5|    5|    6|
    |   8|   4|   1|   4|   9|   8|   1|   1|   7|    9|    9|
    |   3|   6|   7|   2|   0|   3|   5|   9|   4|    4|    6|
    |   4|   4|   3|   4|   4|   8|   4|   3|   7|    5|    5|
    |   0|   1|   5|   9|   3|   0|   5|   0|   1|    2|    4|
    |   2|   0|   3|   2|   0|   7|   5|   9|   0|    2|    7|
    |   2|   9|   2|   3|   3|   2|   3|   4|   1|    2|    9|
    |   1|   4|   6|   8|   2|   3|   0|   0|   6|    0|    6|
    |   3|   3|   8|   8|   8|   2|   3|   2|   0|    8|    8|
    |   3|   8|   2|   8|   4|   3|   0|   4|   3|    6|    9|
    |   8|   0|   8|   5|   9|   0|   9|   6|   5|    3|    1|
    |   8|   0|   4|   9|   6|   5|   7|   8|   8|    9|    2|
    |   8|   6|   6|   9|   1|   6|   8|   8|   3|    2|    3|
    |   6|   3|   6|   5|   7|   0|   8|   4|   6|    5|    8|
    |   2|   3|   9|   7|   5|   3|   4|   5|   3|    3|    7|
    |   9|   9|   9|   7|   3|   2|   3|   9|   7|    7|    5|
    +----+----+----+----+----+----+----+----+----+-----+-----+

有几种方法可以将行值转换为列表:

  • 使用等于列名的键创建map,并将值等于相应的行值。

    import pyspark.sql.functions as psf
    from itertools import chain
    df = df\
        .withColumn("id", psf.monotonically_increasing_id())\
        .select(
            "id", 
            psf.posexplode(
                psf.create_map(list(chain(*[(psf.lit(c), psf.col(c)) for c in df.columns if c != "col11"])))
            ).alias("pos", "col_name", "value"), "col11")
    df.show()
        +---+---+--------+-----+-----+
        | id|pos|col_name|value|col11|
        +---+---+--------+-----+-----+
        |  0|  0|    col1|    5|    7|
        |  0|  1|    col2|    0|    7|
        |  0|  2|    col3|    3|    7|
        |  0|  3|    col4|    3|    7|
        |  0|  4|    col5|    7|    7|
        |  0|  5|    col6|    9|    7|
        |  0|  6|    col7|    3|    7|
        |  0|  7|    col8|    5|    7|
        |  0|  8|    col9|    2|    7|
        |  0|  9|   col10|    4|    7|
        |  1|  0|    col1|    6|    9|
        |  1|  1|    col2|    8|    9|
        |  1|  2|    col3|    8|    9|
        |  1|  3|    col4|    1|    9|
        |  1|  4|    col5|    6|    9|
        |  1|  5|    col6|    7|    9|
        |  1|  6|    col7|    7|    9|
        |  1|  7|    col8|    8|    9|
        |  1|  8|    col9|    1|    9|
        |  1|  9|   col10|    5|    9|
        +---+---+--------+-----+-----+
    
  • StructType

    中使用ArrayType
    df = df\
        .withColumn("id", psf.monotonically_increasing_id())\
        .select(
            "id", 
            psf.explode(
                psf.array([psf.struct(psf.lit(c).alias("col_name"), psf.col(c).alias("value")) 
                           for c in df.columns if c != "col11"])).alias("cols"), 
            "col11").select("cols.*", "col11", "id")
    df.show()
        +--------+-----+-----+---+
        |col_name|value|col11| id|
        +--------+-----+-----+---+
        |    col1|    5|    7|  0|
        |    col2|    0|    7|  0|
        |    col3|    3|    7|  0|
        |    col4|    3|    7|  0|
        |    col5|    7|    7|  0|
        |    col6|    9|    7|  0|
        |    col7|    3|    7|  0|
        |    col8|    5|    7|  0|
        |    col9|    2|    7|  0|
        |   col10|    4|    7|  0|
        |    col1|    6|    9|  1|
        |    col2|    8|    9|  1|
        |    col3|    8|    9|  1|
        |    col4|    1|    9|  1|
        |    col5|    6|    9|  1|
        |    col6|    7|    9|  1|
        |    col7|    7|    9|  1|
        |    col8|    8|    9|  1|
        |    col9|    1|    9|  1|
        |   col10|    5|    9|  1|
        +--------+-----+-----+---+
    
  • 使用ArrayType ...

获得爆炸列表后,您可以查找|col11 - value|的最小值:

from pyspark.sql import Window
w = Window.partitionBy("id").orderBy(psf.abs(psf.col("col11") - psf.col("value")))
res = df.withColumn("rn", psf.row_number().over(w)).filter("rn = 1")
res.sort("id").show()
    +--------+-----+-----+----------+---+
    |col_name|value|col11|        id| rn|
    +--------+-----+-----+----------+---+
    |    col5|    7|    7|         0|  1|
    |    col2|    8|    9|         1|  1|
    |    col1|    8|    8|         2|  1|
    |    col2|    3|    4|         3|  1|
    |    col1|    7|    6|         4|  1|
    |    col5|    9|    9|         5|  1|
    |    col2|    6|    6|         6|  1|
    |   col10|    5|    5|         7|  1|
    |    col3|    5|    4|         8|  1|
    |    col6|    7|    7|         9|  1|
    |    col2|    9|    9|8589934592|  1|
    |    col3|    6|    6|8589934593|  1|
    |    col3|    8|    8|8589934594|  1|
    |    col2|    8|    9|8589934595|  1|
    |    col2|    0|    1|8589934596|  1|
    |    col2|    0|    2|8589934597|  1|
    |    col9|    3|    3|8589934598|  1|
    |    col7|    8|    8|8589934599|  1|
    |    col4|    7|    7|8589934600|  1|
    |    col4|    7|    5|8589934601|  1|
    +--------+-----+-----+----------+---+