我有一个“文本”列,其中存储了令牌数组。如何过滤所有这些数组,以使令牌的长度至少为三个字母?
from pyspark.sql.functions import regexp_replace, col
from pyspark.sql.session import SparkSession
spark = SparkSession.builder.getOrCreate()
columns = ['id', 'text']
vals = [
(1, ['I', 'am', 'good']),
(2, ['You', 'are', 'ok']),
]
df = spark.createDataFrame(vals, columns)
df.show()
# Had tried this but have TypeError: Column is not iterable
# df_clean = df.select('id', regexp_replace('text', [len(word) >= 3 for word
# in col('text')], ''))
# df_clean.show()
我希望看到:
id | text
1 | [good]
2 | [You, are]
答案 0 :(得分:0)
这样做,您可以决定是否排除行,我添加了一个额外的列并过滤掉了,但是选项是您自己的:
from pyspark.sql import functions as f
columns = ['id', 'text']
vals = [
(1, ['I', 'am', 'good']),
(2, ['You', 'are', 'ok']),
(3, ['ok'])
]
df = spark.createDataFrame(vals, columns)
#df.show()
df2 = df.withColumn("text_left_over", f.expr("filter(text, x -> not(length(x) < 3))"))
df2.show()
# This is the actual piece of logic you are looking for.
df3 = df.withColumn("text_left_over", f.expr("filter(text, x -> not(length(x) < 3))")).where(f.size(f.col("text_left_over")) > 0).drop("text")
df3.show()
返回:
+---+--------------+--------------+
| id| text|text_left_over|
+---+--------------+--------------+
| 1| [I, am, good]| [good]|
| 2|[You, are, ok]| [You, are]|
| 3| [ok]| []|
+---+--------------+--------------+
+---+--------------+
| id|text_left_over|
+---+--------------+
| 1| [good]|
| 2| [You, are]|
+---+--------------+
答案 1 :(得分:0)
这是解决方案
filter_length_udf = udf(lambda row: [x for x in row if len(x) >= 3], ArrayType(StringType()))
df_final_words = df_stemmed.withColumn('words_filtered', filter_length_udf(col('words')))