Applying a udf function in a distributed way in PySpark

Time: 2018-04-10 13:04:42

Tags: python apache-spark pyspark apache-spark-sql spark-dataframe

Suppose I have a very basic Spark DataFrame consisting of a couple of columns, one of which contains a value that I want to modify.

|| value   || lang ||
| 3        |  en   |
| 4        |  ua   |

Say I would like to have a new column for each particular class in which I add a float to the given value (this has little to do with the actual question; in reality I run an sklearn prediction there, but for simplicity let's assume we are just adding something, the idea being to modify the value in some way). So, given a dict classes={'1':2.0, '2':3.0}, I would like to have a column for each class in which I add the class's value to the value in the DF, and then save each result to a csv:

class_1.csv
|| value   || lang ||  my_class ||  modified  ||
| 3        |  en   |     1      |     5.0    |  # this is 3+2.0
| 4        |  ua   |     1      |     6.0    |  # this is 4+2.0

class_2.csv
|| value   || lang ||  my_class ||  modified  ||
| 3        |  en   |     2      |     6.0    |  # this is 3+3.0
| 4        |  ua   |     2      |     7.0    |  # this is 4+3.0

So far I have the following code, which works and modifies the value for each defined class, but it does so with a for loop and I am looking for a more advanced optimization:

import pyspark
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf, lit

# create the session (the SparkContext is available as spark.sparkContext if needed)
spark = pyspark.sql.SparkSession.builder.master("yarn").appName("SomeApp").getOrCreate()

# header/inferSchema so that "value" is read as a number rather than a string
my_df = spark.read.csv("some_file.csv", header=True, inferSchema=True)

# modify the value here
def do_stuff_to_column(value, separate_class):
    # pretend we just add a class-specific amount that is read from a dictionary
    class_dict = {'1': 2.0, '2': 3.0}  # would be loaded from somewhere
    return float(value + class_dict[separate_class])

# the classes to iterate over later
class_dict = {'1': 2.0, '2': 3.0}  # in reality there are more than 10 classes

# wrap the function as a udf
udf_modify = udf(do_stuff_to_column, FloatType())

# loop over each class
for my_class in class_dict:
    # create the class column first with lit
    my_df2 = my_df.withColumn("my_class", lit(my_class))
    # compute the modified value with the udf
    my_df2 = my_df2.withColumn("modified", udf_modify("value", "my_class"))
    # write to csv
    my_df2.write.format("csv").save("class_" + my_class + ".csv")

So the question is: is there a better/faster way of doing this than with a for loop?

1 Answer:

Answer 0 (score: 1)

I would use some form of join, in this case crossJoin. Here is an MWE:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

# the original data and a small DataFrame holding one row per class
df = spark.createDataFrame([(3, 'en'), (4, 'ua')], ['value', 'lang'])
classes = spark.createDataFrame([(1, 2.), (2, 3.)], ['class_key', 'class_value'])

# the cross join yields one row per (value, class) pair, so the whole
# computation stays inside Spark instead of a Python-side loop
res = df.crossJoin(classes).withColumn('modified', F.col('value') + F.col('class_value'))
res.show()
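
For reference, res.show() on the MWE above should print something like this (row order may vary):

+-----+----+---------+-----------+--------+
|value|lang|class_key|class_value|modified|
+-----+----+---------+-----------+--------+
|    3|  en|        1|        2.0|     5.0|
|    3|  en|        2|        3.0|     6.0|
|    4|  ua|        1|        2.0|     6.0|
|    4|  ua|        2|        3.0|     7.0|
+-----+----+---------+-----------+--------+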

For saving separate CSVs I don't think there is a better way than a loop.
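
As a rough sketch of what that write loop could look like (my assumption, reusing res and classes from the MWE above; the output paths are placeholders, and as usual each save() produces a directory of CSV part files):

# collect the distinct class keys to the driver, then write one filtered pass per class
class_keys = [row['class_key'] for row in classes.select('class_key').collect()]
for k in class_keys:
    res.filter(F.col('class_key') == k) \
       .write.format('csv') \
       .mode('overwrite') \
       .save('class_{}.csv'.format(k))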