使用udf传递列作为参数将自定义列添加到pyspark数据帧

时间:2018-11-30 05:24:04

标签: python apache-spark pyspark apache-spark-sql user-defined-functions

我有一个包含两列的spark数据框,并且我试图添加一个新列,为这些列引用一个新值。我正在从包含正确列值的字典中获取此值

+--------------+--------------------+
|       country|                 zip|
+--------------+--------------------+
|        Brazil|                7541|
|United Kingdom|                5678|
|         Japan|                1234|
|       Denmark|                2345|
|        Canada|                4567|
|         Italy|                6031|
|        Sweden|                4205|
|        France|                6111|
|         Spain|                8555|
|         India|                2552|
+--------------+--------------------+

该国家/地区的正确值应为印度,而邮政编码应为1234,并存储在字典中

column_dict = {'country' : 'India', zip: 1234}

我正在尝试将新列的值设置为“巴西:印度,邮政编码:1234”,其中该列的值与这些值有所不同。

我已经按照以下方式尝试过,但是它返回的是空列,但是函数返回的是所需的值

     cols = list(df.columns)
     col_list = list(column_dict.keys())

def update(df, cols = cols , col_list = col_list):
   z = []
   for col1, col2 in zip(cols,col_list):
      if col1 == col2:
         if df.col1 != column_dict[col2]: 
            z.append("{'col':" + col2  + ", 'reco': " + str(column_dict[col2]) + "}")   
         else:
            z.append("{'col':" + col2  + ", 'reco': }")

my_udf = udf(lambda x: update(x, cols, col_list))
z = y.withColumn("NewValue", lit(my_udf(y, cols,col_list)))

如果我将相同的输出数据帧导出到csv值,则这些部分将附加'\'。如何以准确的方式获取列上的函数值?

1 个答案:

答案 0 :(得分:0)

一种简单的方法是从dictionaryunion()到主数据帧中创建一个数据帧,然后groupby并获得last值。在这里您可以执行以下操作:

sc = SparkContext.getOrCreate()

newDf = sc.parallelize([
    {'country' : 'India', 'zip': 1234}
]).toDF()

newDF.show()

newDF:

+-------+----+
|country| zip|
+-------+----+
|  India|1234|
+-------+----+

和finalDF:

unionDF = df.union(newDF)

unionDF.show()
+--------------+--------------------+
|       country|                 zip|
+--------------+--------------------+
|        Brazil|                7541|
|United Kingdom|                5678|
|         Japan|                1234|
|       Denmark|                2345|
|        Canada|                4567|
|         Italy|                6031|
|        Sweden|                4205|
|        France|                6111|
|         Spain|                8555|
|         India|                2552|
|         India|                1234|
+--------------+--------------------+

,最后执行groupbylast

import pyspark.sql.functions as f

finalDF = unionDF.groupbby('country').agg(f.last('zip'))

finalDF.show()

+--------------+--------------------+
|       country|                 zip|
+--------------+--------------------+
|        Brazil|                7541|
|United Kingdom|                5678|
|         Japan|                1234|
|       Denmark|                2345|
|        Canada|                4567|
|         Italy|                6031|
|        Sweden|                4205|
|        France|                6111|
|         Spain|                8555|
|         India|                1234|
+--------------+--------------------+