I have a list of sentences in a PySpark (v2.4.5) dataframe, with a matching set of scores. Both the sentences and the scores are stored as lists:
df = spark.createDataFrame(
    [
        (1, ['foo1', 'foo2', 'foo3'], [0.1, 0.5, 0.6]),  # create your data here, be consistent in the types
        (2, ['bar1', 'bar2', 'bar3'], [0.5, 0.7, 0.7]),
        (3, ['baz1', 'baz2', 'baz3'], [0.1, 0.2, 0.3]),
    ],
    ['id', 'txt', 'score']  # add your column labels here
)
df.show()
+---+------------------+---------------+
| id| txt| score|
+---+------------------+---------------+
| 1|[foo1, foo2, foo3]|[0.1, 0.5, 0.6]|
| 2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
| 3|[baz1, baz2, baz3]|[0.1, 0.2, 0.3]|
+---+------------------+---------------+
I want to filter and return only the sentences whose score is >= 0.5:
+---+------------------+---------------+
| id| txt| score|
+---+------------------+---------------+
| 1| [foo2, foo3]| [0.5, 0.6]|
| 2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
Any suggestions?
I tried pyspark dataframe filter or include based on list, but I couldn't get it to work for my case.
Answer 0 (score: 3)
With Spark 2.4+ you have access to higher-order functions, so you can filter the zipped arrays on a condition and then filter out the rows left with empty arrays:
import pyspark.sql.functions as F

# filter the zipped (txt, score) structs, keeping only pairs with score >= 0.5
e = F.expr('filter(arrays_zip(txt, score), x -> x.score >= 0.5)')

# pull each field back out of the filtered structs, then drop rows left empty
df.withColumn("txt", e.txt).withColumn("score", e.score).filter(F.size(e) > 0).show()
+---+------------------+---------------+
| id| txt| score|
+---+------------------+---------------+
| 1| [foo2, foo3]| [0.5, 0.6]|
| 2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
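As an aside: on Spark 3.1+ the same higher-order functions are exposed directly in the Python API, so the F.expr string can be replaced with native calls. This is only a sketch for newer versions; on the asker's 2.4.5, the expr form above is the way to go.

import pyspark.sql.functions as F

# Spark 3.1+ only: arrays_zip/filter through the native Python wrappers
zipped = F.filter(F.arrays_zip("txt", "score"), lambda x: x["score"] >= 0.5)
df.withColumn("txt", zipped["txt"]) \
  .withColumn("score", zipped["score"]) \
  .filter(F.size(zipped) > 0) \
  .show()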
Answer 1 (score: 0)
Give this a try; I couldn't think of a way to do it without a UDF:
import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, BooleanType, StringType

# UDF for the boolean index
filter_udf = udf(lambda arr: [x >= 0.5 for x in arr], ArrayType(BooleanType()))

# UDF for filtering on the boolean index
filter_udf_bool = udf(lambda col_arr, bool_arr: [x for (x, y) in zip(col_arr, bool_arr) if y], ArrayType(StringType()))
df2 = df.withColumn("test", filter_udf("score"))
df3 = df2.withColumn("txt", filter_udf_bool("txt", "test")).withColumn("score", filter_udf_bool("score", "test"))
Output:
# Further filtering for empty arrays:
df3.drop("test").filter(F.size(F.col("txt")) > 0).show()
+---+------------------+---------------+
| id| txt| score|
+---+------------------+---------------+
| 1| [foo2, foo3]| [0.5, 0.6]|
| 2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
You could actually generalize this by combining the two UDFs into a single one; I split them up here for simplicity.
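For what it's worth, here is a minimal sketch of that combined version, with one UDF returning both filtered arrays as a single struct (the kept_schema and kept names are my own, not from the answer):

from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructField, StructType

# one UDF that filters both arrays in a single pass and returns them as a struct
kept_schema = StructType([
    StructField("txt", ArrayType(StringType())),
    StructField("score", ArrayType(DoubleType())),
])
combined_udf = udf(
    lambda txts, scores: (
        [t for t, s in zip(txts, scores) if s >= 0.5],
        [s for s in scores if s >= 0.5],
    ),
    kept_schema,
)

df2 = (df.withColumn("kept", combined_udf("txt", "score"))
         .select("id", col("kept.txt").alias("txt"), col("kept.score").alias("score")))

The final filter on empty arrays stays the same as above.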
Answer 2 (score: -1)
In Spark, user-defined functions are treated as black boxes, since the Catalyst optimizer cannot optimize the code inside a UDF. So avoid UDFs wherever possible.
Here is an example without a UDF:
import pyspark.sql.functions as f

# zip the two arrays into structs, explode to one row per (txt, score) pair,
# keep the pairs passing the threshold, then collect back into arrays per id
(df.withColumn('combined', f.explode(f.arrays_zip('txt', 'score')))
   .filter(f.col('combined.score') >= 0.5)
   .groupby('id')
   .agg(f.collect_list('combined.txt').alias('txt'),
        f.collect_list('combined.score').alias('score'))
   .show())
+---+------------------+---------------+
| id| txt| score|
+---+------------------+---------------+
| 1| [foo2, foo3]| [0.5, 0.6]|
| 2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
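One caveat with this approach: after the shuffle introduced by groupby, collect_list gives no guarantee that elements come back in their original order. Here is a sketch of an order-preserving variant, which carries the element position from posexplode along and sorts on it inside each group (the pos/c/s names are just illustrative):

import pyspark.sql.functions as f

# carry the element position through the explode, sort on it per group,
# then strip it back off with transform (a Spark 2.4+ higher-order function)
(df.select('id', f.posexplode(f.arrays_zip('txt', 'score')).alias('pos', 'c'))
   .filter(f.col('c.score') >= 0.5)
   .groupby('id')
   .agg(f.sort_array(f.collect_list(f.struct('pos', 'c'))).alias('s'))
   .select('id',
           f.expr('transform(s, x -> x.c.txt)').alias('txt'),
           f.expr('transform(s, x -> x.c.score)').alias('score'))
   .show())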
Hope this works for you.
Answer 3 (score: -2)
The column score is an array type, which needs to be filtered further with a predicate.
A snippet for filtering on the array column:
def score_filter(row):
    # keep the whole row if at least one of its scores is >= 0.5
    score_filtered = [s for s in row.score if s >= 0.5]
    if len(score_filtered) > 0:
        yield row

filtered = df.rdd.flatMap(score_filter).toDF()
filtered.show()
Output:
+---+------------------+---------------+
| id| txt| score|
+---+------------------+---------------+
| 1|[foo1, foo2, foo3]|[0.1, 0.5, 0.6]|
| 2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
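Note that this only drops rows; the arrays themselves come through untrimmed, which is why the output above still shows [foo1, foo2, foo3] rather than the asker's expected [foo2, foo3]. A hypothetical variant that also trims both lists, rebuilding each Row by hand:

from pyspark.sql import Row

def score_filter_trim(row):
    # keep only the (txt, score) pairs with score >= 0.5
    kept = [(t, s) for t, s in zip(row.txt, row.score) if s >= 0.5]
    if kept:
        txt, score = zip(*kept)
        yield Row(id=row.id, txt=list(txt), score=list(score))

# Row(**kwargs) may order fields alphabetically on Spark 2.x, so re-select
df.rdd.flatMap(score_filter_trim).toDF().select('id', 'txt', 'score').show()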