如何基于Pyspark中的数组列中的值创建新列

时间:2018-07-17 13:57:29

标签: python arrays apache-spark pyspark apache-spark-sql

我有以下带有代表产品代码的数据框:

testdata = [(0, ['a','b','d']), (1, ['c']), (2, ['d','e'])]
df = spark.createDataFrame(testdata, ['id', 'codes'])
df.show()
+---+---------+
| id|    codes|
+---+---------+
|  0|[a, b, d]|
|  1|      [c]|
|  2|   [d, e]|
+---+---------+

假设代码ab代表T恤,代码c代表毛衣。

tshirts = ['a','b']
sweaters = ['c']

如何创建列label,该列检查这些代码是否在array列中并返回产品名称。像这样:

+---+---------+--------+
| id|    codes|   label|
+---+---------+--------+
|  0|[a, b, d]| tshirts|
|  1|      [c]|sweaters|
|  2|   [d, e]|    none|
+---+---------+--------+

我已经尝试了很多方法,其中包括以下不起作用的方法:

codes = {
    'tshirts': ['a','b'],
    'sweaters': ['c']
}

def any_isin(ref_values, array_to_search):
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            return key
        else:
            return 'none'

any_isin_udf = lambda ref_values: (F.udf(lambda array_to_search: any_isin_mod(ref_values, array_to_search), StringType()))

df_labeled = df.withColumn('label', any_isin_udf(codes)(F.col('codes')))

df_labeled.show()
+---+---------+-------+
| id|    codes|  label|
+---+---------+-------+
|  0|[a, b, d]|tshirts|
|  1|      [c]|   none|
|  2|   [d, e]|   none|
+---+---------+-------+

2 个答案:

答案 0 :(得分:2)

我会用array_contains表示。让我们将输入定义为dict

from pyspark.sql.functions import expr, lit, when
from operator import and_
from functools import reduce

label_map = {"tshirts": ["a", "b"], "sweaters": ["c"]}

下一步生成表达式:

expression_map = {
   label: reduce(and_, [expr("array_contains(codes, '{}')".format(code))
   for code in codes]) for label, codes in label_map.items()
}

最后用CASE ... WHEN减小它:

label = reduce(
    lambda acc, kv: when(kv[1], lit(kv[0])).otherwise(acc),
    expression_map.items(), 
    lit(None).cast("string")
).alias("label")

结果:

df.withColumn("label", label).show()
# +---+---------+--------+                                                        
# | id|    codes|   label|
# +---+---------+--------+
# |  0|[a, b, d]| tshirts|
# |  1|      [c]|sweaters|
# |  2|   [d, e]|    null|
# +---+---------+--------+

答案 1 :(得分:0)

首选使用pyspark.sql.functions.array_contains()之类的@user10055507answer之类的非udf方法,但这是导致代码失败的原因的解释:

错误是您在循环内调用return,因此您永远不会迭代第一个key。这是修改udf以获得所需结果的一种方法:

import pyspark.sql.functions as f

codes = {
    'tshirts': ['a','b'],
    'sweaters': ['c']
}

def any_isin(ref_values, array_to_search):
    label = 'none'
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            label=key
            break
    return label

any_isin_udf = lambda ref_values: (
    f.udf(lambda array_to_search: any_isin(ref_values, array_to_search), StringType())
)

df_labeled = df.withColumn('label', any_isin_udf(codes)(f.col('codes')))

df_labeled.show()
#+---+---------+--------+
#| id|    codes|   label|
#+---+---------+--------+
#|  0|[a, b, d]| tshirts|
#|  1|      [c]|sweaters|
#|  2|   [d, e]|    none|
#+---+---------+--------+

更新

这是使用join的另一种非udf方法:

首先将codes字典变成表格:

import pyspark.sql.functions as f
from itertools import chain

codes_df = spark.createDataFrame(
    list(chain.from_iterable(zip([a]*len(b), b) for a, b in codes.items())),
    ["label", "code"]
)
codes_df.show()
#+--------+----+
#|   label|code|
#+--------+----+
#| tshirts|   a|
#| tshirts|   b|
#|sweaters|   c|
#+--------+----+

现在在布尔值上执行dfcodes_df的左连接,以指示codes数组是否包含代码:

df.alias('l')\
    .join(
        codes_df.alias('r'),
        how='left',
        on=f.expr('array_contains(l.codes, r.code)')
    )\
    .select('id', 'codes', 'label')\
    .distinct()\
    .show()
#+---+---------+--------+
#| id|    codes|   label|
#+---+---------+--------+
#|  2|   [d, e]|    null|
#|  0|[a, b, d]| tshirts|
#|  1|      [c]|sweaters|
#+---+---------+--------+