过滤Spark Dataframe中的列以查找每个元素的百分比

时间:2016-10-16 14:12:40

标签: python filtering pyspark spark-dataframe pyspark-sql

我正在尝试使用pyspark过滤Spark Dataframe中的列,我想知道哪些记录占总列数的10%或更少,

例如,我的DataFrame中有以下标题为“Animal”的列:

动物

  • 大鼠

要找到记录“鼠”,我试过

df.filter(df.groupBy("Animal").count() <= 0.1 * df.select("Animal").count()).collect()

我收到以下错误“TypeError:condition应该是字符串或列”

如何找到代表低于10%的记录?

PS:在SQL中会更简单吗?

类似的东西:

result = spark.sql("SELECT Animal, COUNT(ANIMAL) FROM Table HAVING COUNT(Animal) < 0.1 * COUNT(Animal))

我知道这是一个简单的操作,但我无法弄清楚如何编码 占总数的10%。

感谢您的帮助!

1 个答案:

答案 0 :(得分:1)

首先必须计算总数,然后在第二步中使用它来过滤。

在浓缩代码(pyspark,spark 2.0)中:

import pyspark.sql.functions as F
df=sqlContext.createDataFrame([['Cat'],['Cat'],['Dog'],['Dog'],
    ['Cat'],['Cat'],['Dog'],['Dog'],['Cat'],['Rat']],['Animal'])
total=df.count()
result=(df.groupBy('Animal').count()
    .withColumn('total',F.lit(total))
    .withColumn('fraction',F.expr('count/total'))
    .filter('fraction>0.1'))
result.show()

给出结果:

+------+-----+-----+--------+
|Animal|count|total|fraction|
+------+-----+-----+--------+
|   Dog|    4|   10|     0.4|
|   Cat|    5|   10|     0.5|
+------+-----+-----+--------+

过滤初始设置:

filtered=df.join(result,df.Animal==result.Animal,'leftsemi')
filtered.show()

'leftsemi'连接将记录保存在结果

中具有匹配键的df中