pyspark udf在条件满足时进行计数

时间:2020-01-14 22:06:12

标签: python-3.x pandas dataframe apache-spark pyspark

问题:我有一个pyspark数据框,我想按列进行汇总,并为每个ID满足特定条件的情况提供一个计数。我的数据集如下:

my_dict = {'ID': {0: u'00319383',
  1: u'00337642',
  2: u'0346945',
  3: u'00400193',
  4: u'00405079',
  5: u'0426407',
  6: u'00445573',
  7: u'00485834',
  8: u'0493307',
  9: u'00501281'},
 'type_A': {0: u'A',
  1: u'A',
  2: u'A',
  3: u'A',
  4: u'A',
  5: u'A',
  6: u'A',
  7: u'A',
  8: u'A',
  9: u'A'},
 'type_B': {0: u'None',
  1: u'B',
  2: u'None',
  3: u'None',
  4: u'None',
  5: u'None',
  6: u'None',
  7: u'None',
  8: u'B',
  9: u'None'},
 'type_C': {0: u'C',
  1: u'C',
  2: u'C',
  3: u'C',
  4: u'C',
  5: u'C',
  6: u'C',
  7: u'C',
  8: u'C',
  9: u'C'},
 'type_D': {0: u'None',
  1: u'None',
  2: u'None',
  3: u'None',
  4: u'None',
  5: u'None',
  6: u'None',
  7: u'D',
  8: u'None',
  9: u'None'}}

目标是通过ID计算产品的出现次数。我用SQL开发了一种解决方案,该解决方案可以满足我的要求:

spark.sql('''
            select total, count(contract_id) as freq
            from 
            (
                select id, (typeA + typeB + typeC + typeD) as total
                from
                    (
                        select id
                        , case when type_A = 'A' then 1 else 0 end as typeA
                        , case when type_B = 'B' then 1 else 0 end as typeB 
                        , case when type_C = 'C' then 1 else 0 end as typeC  
                        , case when type_D = 'D' then 1 else 0 end as typeD  
                        from df 
                    ) a
            ) b

            group by total

         ''').toPandas()

如何使用python / pyspark函数完成此操作?寻找解决此类问题的想法?

1 个答案:

答案 0 :(得分:0)

好吧,这应该可以解决问题:

from pyspark.sql.functions import *
df.select(
    (when(col("type_A") == lit("A"), lit(1)).otherwise(lit(0))+
    when(col("type_B") == lit("B"), lit(1)).otherwise(lit(0))+
    when(col("type_C") == lit("C"), lit(1)).otherwise(lit(0))+
    when(col("type_D") == lit("D"), lit(1)).otherwise(lit(0))).alias("total"), col("id")
).groupBy("total").agg(count(col("id")).alias("freq"))