假设我有一个非常基本的Spark DataFrame,它由几列组成,其中一列包含我想要修改的值。
|| value || lang ||
| 3 | en |
| 4 | ua |
说,我希望每个特定的类都有一个新列,我会在给定的值中添加一个浮点数(这与最终的问题没什么关系,但实际上我在那里用sklearn进行预测,但对于简单性让我们假设我们正在添加东西,我的想法是以某种方式修改价值)。所以给定一个dict classes={'1':2.0, '2':3.0}
我希望每个类都有一个列,我将DF中的值添加到类的值中,然后将其保存到csv:
class_1.csv
|| value || lang || my_class | modified ||
| 3 | en | 1 | 5.0 | # this is 3+2.0
| 4 | ua | 1 | 6.0 | # this is 4+2.0
class_2.csv
|| value || lang || my_class | modified ||
| 3 | en | 2 | 6.0 | # this is 3+3.0
| 4 | ua | 2 | 7.0 | # this is 4+3.0
到目前为止,我有以下代码可以工作并修改每个已定义类的值,但它是通过for循环完成的,我正在寻找更高级的优化:
import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf
from pyspark.sql.functions import lit
# create session and context
spark = pyspark.sql.SparkSession.builder.master("yarn").appName("SomeApp").getOrCreate()
conf = SparkConf().setAppName('Some_App').setMaster("local[*]")
sc = SparkContext.getOrCreate(conf)
my_df = spark.read.csv("some_file.csv")
# modify the value here
def do_stuff_to_column(value, separate_class):
# do stuff to column, let's pretend we just add a specific value per specific class that is read from a dictionary
class_dict = {'1':2.0, '2':3.0} # would be loaded from somewhere
return float(value+class_dict[separate_class])
# iterate over each given class later
class_dict = {'1':2.0, '2':3.0} # in reality have more than 10 classes
# create a udf function
udf_modify = udf(do_stuff_to_column, FloatType())
# loop over each class
for my_class in class_dict:
# create the column first with lit
my_df2 = my_df.withColumn("my_class", lit(my_class))
# modify using udf function
my_df2 = my_df2.withColumn("modified", udf_modify("value","my_class"))
# write to csv now
my_df2.write.format("csv").save("class_"+my_class+".csv")
所以问题是,在for循环中有没有更好/更快的方法呢?
答案 0 :(得分:1)
我会使用某种形式的join
,在这种情况下crossJoin
。这是一个MWE:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(3, 'en'), (4, 'ua')], ['value', 'lang'])
classes = spark.createDataFrame([(1, 2.), (2, 3.)], ['class_key', 'class_value'])
res = df.crossJoin(classes).withColumn('modified', F.col('value') + F.col('class_value'))
res.show()
对于单独的CSV保存,我认为没有比使用循环更好的方法了。