将多列与另一列进行比较时,立即选择较小/较大的值

时间:2021-01-25 05:18:16

标签: python apache-spark pyspark apache-spark-sql pyspark-dataframes

我有可变数量的列,假设在这个例子中,我们有 4 列要比较 (textX) 与具有不同值的单个列 (id):

d =     [
  {'id':  500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000}, 
  {'id': 1500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 2500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 3500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 4500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000},
  {'id': 5500, 'text1': 1000 ,'text2': 2000 ,'text3': 3000, 'text4': 5000}
] 
data = spark.createDataFrame(d)

我想根据“id”的值对 textX 列中的最小和较大值进行操作。 例如,对于 id value = 2500,我想对值 2000 和 3000 进行操作。在 'id' 值为 500 的情况下,它将是 null 和 1000。 我试图将这些作为附加列,例如以获得较低的列值

df_cols = data.columns
thresh_list = [x for x in df_cols if x.startswith('text')]

data.withColumn('inic_th', (col(x) for x in thresh_list if col('id') > col(x)))

但得到一个错误:

<块引用>

col 应该是 Column

我猜这是因为有不止一列符合条件但无法在此处插入。

有没有人有任何解决方案可以根据第三列将操作转换为 2 个值,或者如何正确获得这些边界?实际上,textX 列的数量会有所不同。由于性能问题,我尽可能远离 Pandas 和 UDF。

3 个答案:

答案 0 :(得分:2)

您可以使用 leastgreatest 获取相关列:

import pyspark.sql.functions as F

df = data.withColumn(
    'col1',
    F.greatest(*[
        F.when(F.col(c) < F.col('id'), F.col(c))
        for c in data.columns
    ])
).withColumn(
    'col2',
    F.least(*[
        F.when(F.col(c) > F.col('id'), F.col(c))
        for c in data.columns
    ])
)

df.show()
+----+-----+-----+-----+-----+----+----+
|  id|text1|text2|text3|text4|col1|col2|
+----+-----+-----+-----+-----+----+----+
| 500| 1000| 2000| 3000| 5000|null|1000|
|1500| 1000| 2000| 3000| 5000|1000|2000|
|2500| 1000| 2000| 3000| 5000|2000|3000|
|3500| 1000| 2000| 3000| 5000|3000|5000|
|4500| 1000| 2000| 3000| 5000|3000|5000|
|5500| 1000| 2000| 3000| 5000|5000|null|
+----+-----+-----+-----+-----+----+----+

然后您可以对col1col2进行操作。

答案 1 :(得分:2)

这是一种使用高阶函数的方法,用于 spark >=2.4:


df_cols = data.columns
thresh_list = [x for x in df_cols if x.startswith('text')]

out = (data.select("*",F.sort_array(F.array(*thresh_list)).alias("Arr"))
.withColumn("FirstVal",F.expr('element_at(filter (Arr, x-> x<id),-1)'))
.withColumn("LastVal",F.expr('filter (Arr, x->x>id)[0]')).drop("Arr")
)

out.show(truncate=False)

+----+-----+-----+-----+-----+--------+-------+
|id  |text1|text2|text3|text4|FirstVal|LastVal|
+----+-----+-----+-----+-----+--------+-------+
|500 |1000 |2000 |3000 |5000 |null    |1000   |
|1500|1000 |2000 |3000 |5000 |1000    |2000   |
|2500|1000 |2000 |3000 |5000 |2000    |3000   |
|3500|1000 |2000 |3000 |5000 |3000    |5000   |
|4500|1000 |2000 |3000 |5000 |3000    |5000   |
|5500|1000 |2000 |3000 |5000 |5000    |null   |
+----+-----+-----+-----+-----+--------+-------+

答案 2 :(得分:2)

这是使用 array_maxarray_min 函数以及 when 表达式的另一种方法:

  • lowerBound = max thresh_cols 满足条件 thresh_col < id
  • upperBound = min thresh_cols 满足条件 thresh_col > id
from pyspark.sql import functions as F

result = data.withColumn(
    'lowerBound',
    F.array_max(F.array(*[F.when(F.col(c) < F.col('id'), F.col(c)) for c in thresh_cols]))
).withColumn(
    'upperBound',
    F.array_min(F.array(*[F.when(F.col(c) > F.col('id'), F.col(c)) for c in thresh_cols]))
)

result.show()

#+----+-----+-----+-----+-----+----------+----------+
#|  id|text1|text2|text3|text4|lowerBound|upperBound|
#+----+-----+-----+-----+-----+----------+----------+
#| 500| 1000| 2000| 3000| 5000|      null|      1000|
#|1500| 1000| 2000| 3000| 5000|      1000|      2000|
#|2500| 1000| 2000| 3000| 5000|      2000|      3000|
#|3500| 1000| 2000| 3000| 5000|      3000|      5000|
#|4500| 1000| 2000| 3000| 5000|      3000|      5000|
#|5500| 1000| 2000| 3000| 5000|      5000|      null|
#+----+-----+-----+-----+-----+----------+----------+
相关问题