PySpark: filtering a DataFrame with a condition

Date: 2017-07-31 20:15:33

Tags: pyspark

I have the following sample DataFrame:

l = [('Alice went to wonderland',), ('qwertyuiopqwert some text',), ('hello world',), ('ThisGetsFilteredToo',)]
df = spark.createDataFrame(l)


| Alice went to wonderland  |
| qwertyuiopqwert some text |
| hello world               |
| ThisGetsFilteredToo       |

Given this DataFrame, I want to filter out any row that contains even one word longer than 15 characters. In this example, row 2 contains the word 'qwertyuiopqwert', whose length is greater than 15, so that row should be dropped. Likewise, row 4 should also be removed.

2 answers:

Answer 0 (score: 0)

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType, IntegerType, ArrayType

data = ['athshgthsc asl', 'sdf sdfdsadf sdf', 'arasdfa sdf', 'aa bb', 'aaa bbb ccc', 'dd aa bbb']
df = sqlContext.createDataFrame(data, StringType())

def getLengths(lst):
    # length of each word in the list
    tempLst = []
    for ele in lst:
        tempLst.append(len(ele))
    return tempLst

# split each line into words (an array of strings, not a single string)
getList = udf(lambda data: data.split(), ArrayType(StringType()))
# map each word to its length
getListLen = udf(getLengths, ArrayType(IntegerType()))
# take the longest word's length
getMaxLen = udf(lambda data: max(data), IntegerType())

df = (df.withColumn('splitWords', getList(df.value))
        .withColumn('lengthList', getListLen(col('splitWords')))
        .withColumn('maxLen', getMaxLen('lengthList')))
df.filter(df.maxLen < 5).select('value').show()




+----------------+
|           value|
+----------------+
|  athshgthsc asl|
|sdf sdfdsadf sdf|
|     arasdfa sdf|
|           aa bb|
|     aaa bbb ccc|
|       dd aa bbb|
+----------------+

+----------------+--------------------+----------+------+
|           value|          splitWords|lengthList|maxLen|
+----------------+--------------------+----------+------+
|  athshgthsc asl|   [athshgthsc, asl]|   [10, 3]|    10|
|sdf sdfdsadf sdf|[sdf, sdfdsadf, sdf]| [3, 8, 3]|     8|
|     arasdfa sdf|      [arasdfa, sdf]|    [7, 3]|     7|
|           aa bb|            [aa, bb]|    [2, 2]|     2|
|     aaa bbb ccc|     [aaa, bbb, ccc]| [3, 3, 3]|     3|
|       dd aa bbb|       [dd, aa, bbb]| [2, 2, 3]|     3|
+----------------+--------------------+----------+------+

+-----------+
|      value|
+-----------+
|      aa bb|
|aaa bbb ccc|
|  dd aa bbb|
+-----------+

This can be adapted to filter on length > 15 as in the question. Additional preprocessing can also be applied before splitting the dataset. Here, any row containing a word of 5 or more characters is filtered out.
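For the question's threshold of 15, the same per-row predicate can also be expressed with a regular expression instead of per-word length UDFs; a minimal sketch in plain Python (the helper name `has_long_word` is hypothetical), which in PySpark would correspond to filtering with the built-in `rlike` column method, e.g. `df.filter(~df.value.rlike(r'\S{16,}'))`:

```python
import re

def has_long_word(s, limit=15):
    # True if the line contains a run of limit+1 or more non-whitespace
    # characters, i.e. a word strictly longer than `limit`
    return re.search(r'\S{%d,}' % (limit + 1), s) is not None

rows = ['Alice went to wonderland', 'qwertyuiopqwert some text',
        'hello world', 'ThisGetsFilteredToo']
kept = [r for r in rows if not has_long_word(r)]
# 'ThisGetsFilteredToo' (19 chars) is dropped; 'qwertyuiopqwert' is
# exactly 15 chars, so under a strict > 15 rule that row survives
```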

Answer 1 (score: 0)

While the previous answer appears correct, I think you can do this with a simple user-defined function. Create a function that splits the string and looks for any word with length > 15:

def no_long_words(s):
    # return False as soon as any word exceeds 15 characters
    for word in s.split():
        if len(word) > 15:
            return False
    return True
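Before wrapping it in a UDF, the function can be sanity-checked in plain Python against the question's sample rows:

```python
# plain-Python check of no_long_words against the sample rows
def no_long_words(s):
    for word in s.split():
        if len(word) > 15:
            return False
    return True

samples = ['Alice went to wonderland', 'qwertyuiopqwert some text',
           'hello world', 'ThisGetsFilteredToo']
print([s for s in samples if no_long_words(s)])
# only 'ThisGetsFilteredToo' has a word longer than 15 characters
```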

Create the udf:

from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

no_long_words_udf = udf(no_long_words, BooleanType())

Use the udf to run the filter on the dataframe:
df2 = df.filter(no_long_words_udf('col1'))
df2.show()

+--------------------+
|                col1|
+--------------------+
|Alice went to won...|
|qwertyuiopqwert s...|
|         hello world|
+--------------------+

Note: qwertyuiopqwert is actually 15 characters long, so it is included in the results.
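If 'qwertyuiopqwert' should be dropped as well, the comparison can be loosened to reject words of 15 or more characters; a quick plain-Python illustration of that variant (the name `no_long_words_at_least` is hypothetical):

```python
def no_long_words_at_least(s, limit=15):
    # variant: keep a row only if every word is strictly shorter than `limit`
    return all(len(word) < limit for word in s.split())

print(no_long_words_at_least('qwertyuiopqwert some text'))  # now rejected: False
print(no_long_words_at_least('hello world'))                # still kept: True
```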