包含 StringType 元素的 ArrayType 列上的 UDF 函数

时间:2021-02-09 03:47:52

标签: apache-spark pyspark apache-spark-sql user-defined-functions pyspark-dataframes

我需要一个 udf 函数来输入数据帧的数组列并对其中的两个字符串元素执行相等性检查。我的数据框有这样的架构。

<头>
ID 日期 选项
1 2021-01-06 ['red', 'green']
2 2021-01-07 ['蓝色', '蓝色']
3 2021-01-08 ['蓝色', '黄色']
4 2021-01-09 nan

我已经试过了:

def equality_check(options: list):
  try:
   if options[0] == options[1]:
     return 1
   else:
     return 0
  except:
     return -1

equality_udf = f.udf(equality_check, t.IntegerType())

但它抛出了索引错误。我相信 options 列是字符串数组。 期望是这样的:

<头>
ID 日期 选项 equality_check
1 2021-01-06 ['red', 'green'] 0
2 2021-01-07 ['蓝色', '蓝色'] 1
3 2021-01-08 ['蓝色', '黄色'] 0
4 2021-01-09 nan -1

1 个答案:

答案 0 :(得分:1)

您可以检查 options 列表是否已定义或其长度是否小于 2,而不是使用 try/except。这是一个工作示例:

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

data = [
    (1, "2021-01-06", ['red', 'green']),
    (2, "2021-01-07", ['Blue', 'Blue']),
    (3, "2021-01-08", ['Blue', 'Yellow']),
    (4, "2021-01-09", None),
]
df = spark.createDataFrame(data, ["ID", "date", "options"])

def equality_check(options: list):
    if not options or len(options) < 2:
        return -1

    return int(options[0] == options[1])

equality_udf = F.udf(equality_check, IntegerType())

df1 = df.withColumn("equality_check", equality_udf(F.col("options")))
df1.show()

#+---+----------+--------------+--------------+
#| ID|      date|       options|equality_check|
#+---+----------+--------------+--------------+
#|  1|2021-01-06|  [red, green]|             0|
#|  2|2021-01-07|  [Blue, Blue]|             1|
#|  3|2021-01-08|[Blue, Yellow]|             0|
#|  4|2021-01-09|          null|            -1|
#+---+----------+--------------+--------------+

但是,我建议你不要使用 UDF,因为你可以只使用内置函数来做同样的事情:

df1 = df.withColumn(
    "equality_check",
    F.when(F.size(F.col("options")) < 2, -1)
        .when(F.col("options")[0] == F.col("options")[1], 1)
        .otherwise(0)
)