熊猫用户定义函数(UDF)-是否可以返回布尔值?

时间:2020-02-26 15:02:40

标签: apache-spark pyspark pyspark-sql pyspark-dataframes

我正在尝试将函数编写为Pandas UDF,它将检查字符串数组的任何元素是否以特定值开头。我正在寻找的结果将是这样的:

df.filter(list_contains(val, df.stringArray_column)).show()

list_contains函数将在df.stringArray的任何元素以val开头的每一行上返回 True

只是一个例子:

df = spark.read.csv(path)
display(df.filter(list_contains('50', df.stringArray_column)))

上面的代码将显示df的每一行,其中stringArray列的元素以50开头。

我已经在python中编写了一个函数,该函数非常慢

    def list_contains(val):
    # Perfom what ListContains generated
  def list_contains_udf(column_list):
    for element in column_list:
      if element.startswith(val):
        return True
    return False
  return udf(list_contains_udf, BooleanType())

谢谢您的帮助。

编辑:这是一些示例数据,也是我正在寻找的输出示例:

df.LIST: ["408000","641100"]
         ["633400","641100"]
         ["633400","791100"]
         ["633400","408100"]
         ["633400","641100"]
         ["408110","641230"]
         ["633400","647200"]

display(df.select('LIST').filter(list_contains('408')(df.LIST)))

output: LIST
        ["408000","641100"]
        ["633400","408100"]
        ["408110","641230"]

1 个答案:

答案 0 :(得分:0)

更新后的答案

如果我们假设数组的长度相同,则可以在没有UDF的情况下执行此操作。让我们尝试以下方法。

from pyspark.sql import SparkSession
import pyspark.sql.functions as f
from pyspark.sql.functions import col

spark = SparkSession.builder.appName('prefix_finder').getOrCreate()

# sample data creation
my_df = spark.createDataFrame(
    [('scooby', ['cartoon', 'kidfriendly']),
     ('batman', ['dark', 'cars']),
     ('meshuggah', ['heavy', 'dark']),
     ('guthrie', ['god', 'guitar'])
     ]
    , schema=('character', 'tags'))

数据帧my_df如下所示:

+---------+----------------------+
|character|tags                  |
+---------+----------------------+
|scooby   |[cartoon, kidfriendly]|
|batman   |[dark, cars]          |
|meshuggah|[heavy, dark]         |
|guthrie  |[god, guitar]         |
+---------+----------------------+

如果我们要搜索前缀 car ,则仅返回第一行和第二行,因为 car cartoon 和汽车

以下是本机Spark操作可以实现的目标。

num_items_in_arr = 2 # this was the assumption
prefix = 'car'

my_df2 = my_df.select(col('character'), col('tags'), *(col('tags').getItem(i).alias(f'tag{i}') for i in range(num_items_in_arr)))

数据帧my_df2如下:

+---------+----------------------+-------+-----------+
|character|tags                  |tag0   |tag1       |
+---------+----------------------+-------+-----------+
|scooby   |[cartoon, kidfriendly]|cartoon|kidfriendly|
|batman   |[dark, cars]          |dark   |cars       |
|meshuggah|[insane, heavy]       |insane |heavy      |
|guthrie  |[god, guitar]         |god    |guitar     |
+---------+----------------------+-------+-----------+

让我们在my_df2上创建一列 concat_tags ,我们将其用于正则表达式匹配。

cols_of_interest = [f'tag{i}' for i in range(num_items_in_arr)]

for idx, col_name in enumerate(cols_of_interest):
    my_df2 = my_df2.withColumn(col_name, f.substring(col_name, 1, prefix_len))

    if idx == 0:
        my_df2 = my_df2.withColumn(col_name, f.concat(lit("("), col_name, lit(".*")))
    elif idx == len(cols_to_update_concat) - 1:
        my_df2 = my_df2.withColumn(col_name, f.concat(col_name, lit(".*)")))
    else:
        my_df2 = my_df2.withColumn(col_name, f.concat(col_name, lit(".*")))

my_df3 = my_df2.withColumn('concat_tags', f.concat_ws('|', *cols_of_interest)).drop(*cols_of_interest)

my_df3如下:

+---------+----------------------+-------------+
|character|tags                  |concat_tags  |
+---------+----------------------+-------------+
|scooby   |[cartoon, kidfriendly]|(car.*|kid.*)|
|batman   |[dark, cars]          |(dar.*|car.*)|
|meshuggah|[insane, heavy]       |(ins.*|hea.*)|
|guthrie  |[god, guitar]         |(god.*|gui.*)|
+---------+----------------------+-------------+

现在,我们需要在 concat_tags 列上应用正则表达式数学。

my_df4 = my_df3.withColumn('matched', f.expr(r"regexp_extract(prefix, concat_tags, 0)"))

结果如下:

+---------+----------------------+-------------+-------+
|character|tags                  |concat_tags  |matched|
+---------+----------------------+-------------+-------+
|scooby   |[cartoon, kidfriendly]|(car.*|kid.*)|car    |
|batman   |[dark, cars]          |(dar.*|car.*)|car    |
|meshuggah|[insane, heavy]       |(ins.*|hea.*)|       |
|guthrie  |[god, guitar]         |(god.*|gui.*)|       |
+---------+----------------------+-------------+-------+

一点点清理。

my_df5 = my_df4.filter(my_df4.matched != "").drop('concat_tags', 'matched')

这是最终数据帧:

+---------+----------------------+
|character|tags                  |
+---------+----------------------+
|scooby   |[cartoon, kidfriendly]|
|batman   |[dark, cars]          |
+---------+----------------------+