Iterating over an array column in PySpark with map

Asked: 2019-06-20 15:14:02

Tags: python apache-spark pyspark

In PySpark, I have a DataFrame consisting of two columns:

+-----------+----------------------+
| str1      | array_of_str         |
+-----------+----------------------+
| John      | [mango, apple, ...   |
| Tom       | [mango, orange, ...  |
| Matteo    | [apple, banana, ...  |
+-----------+----------------------+

I want to add a column concat_result that contains each element of array_of_str concatenated with the string from the str1 column:

+-----------+----------------------+----------------------------------+
| str1      | array_of_str         | concat_result                    |
+-----------+----------------------+----------------------------------+
| John      | [mango, apple, ...   | [mangoJohn, appleJohn, ...       |
| Tom       | [mango, orange, ...  | [mangoTom, orangeTom, ...        |
| Matteo    | [apple, banana, ...  | [appleMatteo, bananaMatteo, ...  |
+-----------+----------------------+----------------------------------+

I am trying to iterate over the array with map:

from pyspark.sql import functions as F
from pyspark.sql.types import StringType, ArrayType

# START EXTRACT OF CODE
ret = (df
  .select(['str1', 'array_of_str'])
  .withColumn('concat_result', F.udf(
     map(lambda x: x + F.col('str1'), F.col('array_of_str')), ArrayType(StringType))
  )
)

return ret
# END EXTRACT OF CODE

But I get the error:

TypeError: argument 2 to map() must support iteration

1 Answer:

Answer 0 (score: 2)

You only need a few tweaks to make this work. Your version fails because F.col('array_of_str') returns a Column object, not a Python iterable (hence "argument 2 to map() must support iteration"), and F.udf expects a Python function, not the result of calling one:

from pyspark.sql.types import StringType, ArrayType
from pyspark.sql.functions import udf, col

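# UDF: append con_str to every element of arr via a list comprehension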
concat_udf = udf(lambda con_str, arr: [x + con_str for x in arr],
                 ArrayType(StringType()))
ret = df \
  .select(['str1', 'array_of_str']) \
  .withColumn('concat_result', concat_udf(col("str1"), col("array_of_str")))

ret.show()
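
With the sample data above, ret.show() yields concat_result rows such as [mangoJohn, appleJohn, ...] and [mangoTom, orangeTom, ...].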

You don't need map here; a standard list comprehension inside the UDF is enough.
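
As a side note: if you are on Spark 2.4 or later, you can skip the Python UDF entirely and use the built-in transform higher-order function, which runs inside the JVM and avoids the Python serialization overhead of a UDF. A minimal sketch, assuming the same df and column names:

from pyspark.sql import functions as F

# transform() applies the SQL lambda x -> concat(x, str1) to every
# element of array_of_str; str1 refers to the row's other column.
ret = df.withColumn(
    'concat_result',
    F.expr("transform(array_of_str, x -> concat(x, str1))")
)
ret.show()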