将函数应用于数组列pyspark中的所有值

时间:2019-10-22 12:31:46

标签: arrays apache-spark pyspark user-defined-functions

我想使pyspark数据框中的数组列中的所有值均为负数而不爆炸(!)。我尝试了这个udf,但是没有用:

negative = func.udf(lambda x: x * -1, T.ArrayType(T.FloatType()))
cast_contracts = cast_contracts \
    .withColumn('forecast_values', negative('forecast_values'))

有人可以帮忙吗?

示例数据框:

df = sc..parallelize(
   [Row(name='Joe', forecast_values=[1.0,2.0,3.0]),
    Row(name='Mary', forecast_values=[4.0,7.1])]).toDF()
>>> df.show()
    +----+---------------+
    |name|forecast_values|
    +----+---------------+
    | Joe|[1.0, 2.0, 3.0]|
    |Mary|     [4.0, 7.1]|
    +----+---------------+

谢谢

2 个答案:

答案 0 :(得分:1)

只是您没有遍历列表值以将它们乘以-1

import pyspark.sql.functions as F
import pyspark.sql.types as T

negative = F.udf(lambda x: [i * -1 for i in x], T.ArrayType(T.FloatType()))
cast_contracts = df \
    .withColumn('forecast_values', negative('forecast_values'))

您无法逃脱udf,但最好的方法是逃脱这种情况。如果列表很大,效果会更好:

import numpy as np

negative = F.udf(lambda x: np.negative(x).tolist(), T.ArrayType(T.FloatType()))
cast_contracts = abdf \
    .withColumn('forecast_values', negative('forecast_values'))
cast_contracts.show()
+------------------+----+
|   forecast_values|name|
+------------------+----+
|[-1.0, -2.0, -3.0]| Joe|
|            [-3.0]|Mary|
|      [-4.0, -7.1]|Mary|
+------------------+----+

答案 1 :(得分:0)

我知道这是一个发布了一年的帖子,所以我要提供的解决方案以前可能不是一个选项(Spark 3的新功能)。如果您在PySpark API中使用spark 3.0及更高版本,则应考虑在spark.sql.function.transform内使用pyspark.sql.functions.expr。 请不要将spark.sql.function.transform与PySpark的transform()链接混淆。无论如何,这是解决方案:

df.withColumn("negative", F.expr("transform(forecast_values, x -> x * -1)"))

您只需要确保将值转换为int或float。强调的方法比爆炸数组或使用UDF效率更高。