使用条件在Pyspark数据帧中转置多列

时间:2017-10-31 17:48:17

标签: python apache-spark pyspark apache-spark-sql spark-dataframe

我有一个看起来像这样的火花数据框。

id cd1 version1   dt1   cd2 version2  dt2      cd3 version3    dt3
1  100    1    20100101 101    1     20100101  102            20100301        
1  101    1    20100102 102          20100201  100    1       20100302
2  201    1    20100103 100    1     20100301  100    1       20100303
2  202    2    20100104 100    1     20100105

我需要将所有代码转换为具有以下条件的单个列

  • 如果相应的版本代码为1,请在第一个数字后添加小数点
  • 每位患者都应该有不同的代码

对于上面的示例,输出应该如下所示。

id    code     dt
1     1.00   20100101
1     1.01   20100101
1     102    20100301
1     1.01   20100102
1     102    20100201
1     10.0   20100302
2     2.01   20100103
2     1.00   20100301
2     1.00   20100303
2     202    20100104
2     10.0   20100105

我正在使用Pyspark这样做。在上面的例子中,我只显示了3个代码及其相应的版本列,但我有30个这样的列。此外,这些数据有大约2500万行。

关于如何实现这一目标的任何想法都将非常有用。

2 个答案:

答案 0 :(得分:0)

您可以explode这些列的列表,这样每行只有一个(cd, version)对 首先,让我们创建数据框:

df = sc.parallelize([[1,100,1,101,1,102,None],[1,101,1,102,None,100,1],[2,201,1,100,1,100,1],
                               [2,202,2,100,1,None,None]]).toDF(["id","cd1","version1","cd2","version2","cd3","version3"])
  1. 使用posexplode

    import pyspark.sql.functions as psf
    from itertools import chain
    nb_versions = 4
    df = df.na.fill(-1).select(
        "id", 
        psf.posexplode(psf.create_map(list(chain(*[(psf.col("cd" + str(i)), psf.col("version"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "cd", "version")
    ).drop("pos").filter("cd != -1")
    
        +---+---+-------+
        | id| cd|version|
        +---+---+-------+
        |  1|100|      1|
        |  1|101|      1|
        |  1|102|     -1|
        |  1|101|      1|
        |  1|102|     -1|
        |  1|100|      1|
        |  2|201|      1|
        |  2|100|      1|
        |  2|100|      1|
        |  2|202|      2|
        |  2|100|      1|
        +---+---+-------+
    
  2. 使用explode

    nb_versions = 4
    df = df.select(
        "id", 
        psf.explode(psf.array(
            [psf.struct(
                psf.col("cd" + str(i)).alias("cd"), 
                psf.col("version" + str(i)).alias("version")) for i in range(1, nb_versions)])).alias("temp"))\
        .select("id", "temp.*")
    
        +---+----+-------+
        | id|  cd|version|
        +---+----+-------+
        |  1| 100|      1|
        |  1| 101|      1|
        |  1| 102|   null|
        |  1| 101|      1|
        |  1| 102|   null|
        |  1| 100|      1|
        |  2| 201|      1|
        |  2| 100|      1|
        |  2| 100|      1|
        |  2| 202|      2|
        |  2| 100|      1|
        |  2|null|   null|
        +---+----+-------+
    
  3. 现在我们可以实现您的条件

    • 除以100版本== 1
    • 不同的价值观

    我们会使用函数when, otherwise作为条件,distinct

    df.withColumn("cd", psf.when(df.version == 1, df.cd/100).otherwise(df.cd))\
        .distinct().drop("version")
    
        +---+-----+
        | id|   cd|
        +---+-----+
        |  1|  1.0|
        |  1| 1.01|
        |  1|102.0|
        |  2|  1.0|
        |  2| 2.01|
        |  2|202.0|
        +---+-----+
    

答案 1 :(得分:-1)

这就是我做的。我相信有更好的方法可以做到这一点。

def process_code(raw_data):
    for i in range(1,4):
        cd_col_name = "cd" + str(i)
        version_col_name = "version" + str(i)
        raw_data = raw_data.withColumn("mod_cd" + str(i), when(raw_data[version_col_name] == 1, concat(substring(raw_data[cd_col_name],1,1),lit("."),substring(raw_data[cd_col_name],2,20))).otherwise(raw_data[cd_col_name]))

    mod_cols = [col for col in raw_data.columns if 'mod_cd' in col]
    nb_versions = 3
    new = raw_data.fillna('9999', subset=mod_cols).select("id", psf.posexplode(psf.create_map(list(chain(*[(psf.col("mod_cd" + str(i)), psf.col("dt"+str(i))) for i in range(1, nb_versions)])))).alias("pos", "final_cd", "final_date")).drop("pos")
    return new

test = process_code(df)
test = test.filter(test.final_cd != '9999')
test.show(100, False)