我有一个看起来像这样的火花数据框。
id cd1 version1 dt1 cd2 version2 dt2 cd3 version3 dt3
1 100 1 20100101 101 1 20100101 102 20100301
1 101 1 20100102 102 20100201 100 1 20100302
2 201 1 20100103 100 1 20100301 100 1 20100303
2 202 2 20100104 100 1 20100105
我需要将所有代码转换为具有以下条件的单个列
对于上面的示例,输出应该如下所示。
id code dt
1 1.00 20100101
1 1.01 20100101
1 102 20100301
1 1.01 20100102
1 102 20100201
1 10.0 20100302
2 2.01 20100103
2 1.00 20100301
2 1.00 20100303
2 202 20100104
2 10.0 20100105
我正在使用Pyspark这样做。在上面的例子中,我只显示了3个代码及其相应的版本列,但我有30个这样的列。此外,这些数据有大约2500万行。
关于如何实现这一目标的任何想法都将非常有用。
答案 0 :(得分:0)
您可以explode
这些列的列表,这样每行只有一个(cd, version)
对
首先,让我们创建数据框:
df = sc.parallelize([[1,100,1,101,1,102,None],[1,101,1,102,None,100,1],[2,201,1,100,1,100,1],
[2,202,2,100,1,None,None]]).toDF(["id","cd1","version1","cd2","version2","cd3","version3"])
使用posexplode
:
import pyspark.sql.functions as psf
from itertools import chain
nb_versions = 4
df = df.na.fill(-1).select(
"id",
psf.posexplode(psf.create_map(list(chain(*[(psf.col("cd" + str(i)), psf.col("version"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "cd", "version")
).drop("pos").filter("cd != -1")
+---+---+-------+
| id| cd|version|
+---+---+-------+
| 1|100| 1|
| 1|101| 1|
| 1|102| -1|
| 1|101| 1|
| 1|102| -1|
| 1|100| 1|
| 2|201| 1|
| 2|100| 1|
| 2|100| 1|
| 2|202| 2|
| 2|100| 1|
+---+---+-------+
使用explode
:
nb_versions = 4
df = df.select(
"id",
psf.explode(psf.array(
[psf.struct(
psf.col("cd" + str(i)).alias("cd"),
psf.col("version" + str(i)).alias("version")) for i in range(1, nb_versions)])).alias("temp"))\
.select("id", "temp.*")
+---+----+-------+
| id| cd|version|
+---+----+-------+
| 1| 100| 1|
| 1| 101| 1|
| 1| 102| null|
| 1| 101| 1|
| 1| 102| null|
| 1| 100| 1|
| 2| 201| 1|
| 2| 100| 1|
| 2| 100| 1|
| 2| 202| 2|
| 2| 100| 1|
| 2|null| null|
+---+----+-------+
现在我们可以实现您的条件
我们会使用函数when, otherwise
作为条件,distinct
:
df.withColumn("cd", psf.when(df.version == 1, df.cd/100).otherwise(df.cd))\
.distinct().drop("version")
+---+-----+
| id| cd|
+---+-----+
| 1| 1.0|
| 1| 1.01|
| 1|102.0|
| 2| 1.0|
| 2| 2.01|
| 2|202.0|
+---+-----+
答案 1 :(得分:-1)
这就是我做的。我相信有更好的方法可以做到这一点。
def process_code(raw_data):
for i in range(1,4):
cd_col_name = "cd" + str(i)
version_col_name = "version" + str(i)
raw_data = raw_data.withColumn("mod_cd" + str(i), when(raw_data[version_col_name] == 1, concat(substring(raw_data[cd_col_name],1,1),lit("."),substring(raw_data[cd_col_name],2,20))).otherwise(raw_data[cd_col_name]))
mod_cols = [col for col in raw_data.columns if 'mod_cd' in col]
nb_versions = 3
new = raw_data.fillna('9999', subset=mod_cols).select("id", psf.posexplode(psf.create_map(list(chain(*[(psf.col("mod_cd" + str(i)), psf.col("dt"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "final_cd", "final_date")).drop("pos")
return new
test = process_code(df)
test = test.filter(test.final_cd != '9999')
test.show(100, False)