我正在尝试使用pyspark过滤Spark Dataframe中的列,我想知道哪些记录占总列数的10%或更少,
例如,我的DataFrame中有以下标题为“Animal”的列:
动物
要找到记录“鼠”,我试过
df.filter(df.groupBy("Animal").count() <= 0.1 * df.select("Animal").count()).collect()
我收到以下错误“TypeError:condition应该是字符串或列”
如何找到代表低于10%的记录?
PS:在SQL中会更简单吗?类似的东西:
result = spark.sql("SELECT Animal, COUNT(ANIMAL) FROM Table HAVING COUNT(Animal) < 0.1 * COUNT(Animal))
我知道这是一个简单的操作,但我无法弄清楚如何编码 占总数的10%。
感谢您的帮助!
答案 0 :(得分:1)
首先必须计算总数,然后在第二步中使用它来过滤。
在浓缩代码(pyspark,spark 2.0)中:
import pyspark.sql.functions as F
df=sqlContext.createDataFrame([['Cat'],['Cat'],['Dog'],['Dog'],
['Cat'],['Cat'],['Dog'],['Dog'],['Cat'],['Rat']],['Animal'])
total=df.count()
result=(df.groupBy('Animal').count()
.withColumn('total',F.lit(total))
.withColumn('fraction',F.expr('count/total'))
.filter('fraction>0.1'))
result.show()
给出结果:
+------+-----+-----+--------+
|Animal|count|total|fraction|
+------+-----+-----+--------+
| Dog| 4| 10| 0.4|
| Cat| 5| 10| 0.5|
+------+-----+-----+--------+
过滤初始设置:
filtered=df.join(result,df.Animal==result.Animal,'leftsemi')
filtered.show()
'leftsemi'连接将记录保存在结果
中具有匹配键的df中