我有一个带有一组计算列的Apache Spark数据帧。对于数据框中的每一行(大约2000),我希望获取10列的行值,并找到第11列相对于其他10的最接近的值。
我想我会取这些行值并将其转换为列表然后使用abs值计算来确定最接近的值。
但我仍然坚持如何将行值转换为列表。我已经采用了一个列并使用collect_list将这些值转换为列表但不确定当列表来自单行和多列时如何处理。
答案 0 :(得分:1)
您应该explode
列,以便线性化您的计算。
让我们创建一个示例数据框:
import numpy as np
np.random.seed(0)
df = sc.parallelize([np.random.randint(0, 10, 11).tolist() for _ in range(20)])\
.toDF(["col" + str(i) for i in range(1, 12)])
df.show()
+----+----+----+----+----+----+----+----+----+-----+-----+
|col1|col2|col3|col4|col5|col6|col7|col8|col9|col10|col11|
+----+----+----+----+----+----+----+----+----+-----+-----+
| 5| 0| 3| 3| 7| 9| 3| 5| 2| 4| 7|
| 6| 8| 8| 1| 6| 7| 7| 8| 1| 5| 9|
| 8| 9| 4| 3| 0| 3| 5| 0| 2| 3| 8|
| 1| 3| 3| 3| 7| 0| 1| 9| 9| 0| 4|
| 7| 3| 2| 7| 2| 0| 0| 4| 5| 5| 6|
| 8| 4| 1| 4| 9| 8| 1| 1| 7| 9| 9|
| 3| 6| 7| 2| 0| 3| 5| 9| 4| 4| 6|
| 4| 4| 3| 4| 4| 8| 4| 3| 7| 5| 5|
| 0| 1| 5| 9| 3| 0| 5| 0| 1| 2| 4|
| 2| 0| 3| 2| 0| 7| 5| 9| 0| 2| 7|
| 2| 9| 2| 3| 3| 2| 3| 4| 1| 2| 9|
| 1| 4| 6| 8| 2| 3| 0| 0| 6| 0| 6|
| 3| 3| 8| 8| 8| 2| 3| 2| 0| 8| 8|
| 3| 8| 2| 8| 4| 3| 0| 4| 3| 6| 9|
| 8| 0| 8| 5| 9| 0| 9| 6| 5| 3| 1|
| 8| 0| 4| 9| 6| 5| 7| 8| 8| 9| 2|
| 8| 6| 6| 9| 1| 6| 8| 8| 3| 2| 3|
| 6| 3| 6| 5| 7| 0| 8| 4| 6| 5| 8|
| 2| 3| 9| 7| 5| 3| 4| 5| 3| 3| 7|
| 9| 9| 9| 7| 3| 2| 3| 9| 7| 7| 5|
+----+----+----+----+----+----+----+----+----+-----+-----+
有几种方法可以将行值转换为列表:
使用等于列名的键创建map
,并将值等于相应的行值。
import pyspark.sql.functions as psf
from itertools import chain
df = df\
.withColumn("id", psf.monotonically_increasing_id())\
.select(
"id",
psf.posexplode(
psf.create_map(list(chain(*[(psf.lit(c), psf.col(c)) for c in df.columns if c != "col11"])))
).alias("pos", "col_name", "value"), "col11")
df.show()
+---+---+--------+-----+-----+
| id|pos|col_name|value|col11|
+---+---+--------+-----+-----+
| 0| 0| col1| 5| 7|
| 0| 1| col2| 0| 7|
| 0| 2| col3| 3| 7|
| 0| 3| col4| 3| 7|
| 0| 4| col5| 7| 7|
| 0| 5| col6| 9| 7|
| 0| 6| col7| 3| 7|
| 0| 7| col8| 5| 7|
| 0| 8| col9| 2| 7|
| 0| 9| col10| 4| 7|
| 1| 0| col1| 6| 9|
| 1| 1| col2| 8| 9|
| 1| 2| col3| 8| 9|
| 1| 3| col4| 1| 9|
| 1| 4| col5| 6| 9|
| 1| 5| col6| 7| 9|
| 1| 6| col7| 7| 9|
| 1| 7| col8| 8| 9|
| 1| 8| col9| 1| 9|
| 1| 9| col10| 5| 9|
+---+---+--------+-----+-----+
在StructType
ArrayType
df = df\
.withColumn("id", psf.monotonically_increasing_id())\
.select(
"id",
psf.explode(
psf.array([psf.struct(psf.lit(c).alias("col_name"), psf.col(c).alias("value"))
for c in df.columns if c != "col11"])).alias("cols"),
"col11").select("cols.*", "col11", "id")
df.show()
+--------+-----+-----+---+
|col_name|value|col11| id|
+--------+-----+-----+---+
| col1| 5| 7| 0|
| col2| 0| 7| 0|
| col3| 3| 7| 0|
| col4| 3| 7| 0|
| col5| 7| 7| 0|
| col6| 9| 7| 0|
| col7| 3| 7| 0|
| col8| 5| 7| 0|
| col9| 2| 7| 0|
| col10| 4| 7| 0|
| col1| 6| 9| 1|
| col2| 8| 9| 1|
| col3| 8| 9| 1|
| col4| 1| 9| 1|
| col5| 6| 9| 1|
| col6| 7| 9| 1|
| col7| 7| 9| 1|
| col8| 8| 9| 1|
| col9| 1| 9| 1|
| col10| 5| 9| 1|
+--------+-----+-----+---+
使用ArrayType
...
获得爆炸列表后,您可以查找|col11 - value|
的最小值:
from pyspark.sql import Window
w = Window.partitionBy("id").orderBy(psf.abs(psf.col("col11") - psf.col("value")))
res = df.withColumn("rn", psf.row_number().over(w)).filter("rn = 1")
res.sort("id").show()
+--------+-----+-----+----------+---+
|col_name|value|col11| id| rn|
+--------+-----+-----+----------+---+
| col5| 7| 7| 0| 1|
| col2| 8| 9| 1| 1|
| col1| 8| 8| 2| 1|
| col2| 3| 4| 3| 1|
| col1| 7| 6| 4| 1|
| col5| 9| 9| 5| 1|
| col2| 6| 6| 6| 1|
| col10| 5| 5| 7| 1|
| col3| 5| 4| 8| 1|
| col6| 7| 7| 9| 1|
| col2| 9| 9|8589934592| 1|
| col3| 6| 6|8589934593| 1|
| col3| 8| 8|8589934594| 1|
| col2| 8| 9|8589934595| 1|
| col2| 0| 1|8589934596| 1|
| col2| 0| 2|8589934597| 1|
| col9| 3| 3|8589934598| 1|
| col7| 8| 8|8589934599| 1|
| col4| 7| 7|8589934600| 1|
| col4| 7| 5|8589934601| 1|
+--------+-----+-----+----------+---+