提取PySpark中的特定行

时间:2019-04-09 14:21:07

标签: python apache-spark pyspark apache-spark-sql

我有一个这样的数据框

data = [(("ID1", "A", 1)), (("ID1", "B", 5)), (("ID2", "A", 12)), 
       (("ID3", "A", 3)), (("ID3", "B", 3)), (("ID3", "C", 5)), (("ID4", "A", 10))]
df = spark.createDataFrame(data, ["ID", "Type", "Value"])
df.show()

+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID1|   A|    1|
|ID1|   B|    5|
|ID2|   A|   12|
|ID3|   A|    3|
|ID3|   B|    3|
|ID3|   C|    5|
|ID4|   A|   10|
+---+----+-----+

我只想提取仅包含一种特定类型-“ A”的那些行(或ID)

因此,我的预期输出将包含以下行

+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID2|   A|    1|
|ID4|   A|   10|
+---+----+-----+

对于每个ID可以包含任何类型-A,B,C等。我想提取那些仅包含一个且仅一个类型-'A'的ID

我如何在PySpark中实现这一目标

3 个答案:

答案 0 :(得分:4)

您可以对其应用过滤器。

import pyspark.sql.functions as f

data = [(("ID1", "A", 1)), (("ID1", "B", 5)), (("ID2", "A", 12)), 
       (("ID3", "A", 3)), (("ID3", "B", 3)), (("ID3", "C", 5)), (("ID4", "A", 10))]
df = spark.createDataFrame(data, ["ID", "Type", "Value"])
df.show()

+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID1|   A|    1|
|ID1|   B|    5|
|ID2|   A|   12|
|ID3|   A|    3|
|ID3|   B|    3|
|ID3|   C|    5|
|ID4|   A|   10|
+---+----+-----+

x= df.filter(f.col('Type')=='A')

x.show()

如果我们需要过滤只有一条记录并且Type也为'A'的所有ID,那么下面的代码可能是解决方案


df.registerTempTable('table1')


sqlContext.sql('select a.ID, a.Type,a.Value from table1 as a, (select ID, count(*) as cnt_val from table1 group by ID) b where a.ID = b.ID and (a.Type=="A" and b.cnt_val ==1)').show()


+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID2|   A|   12|
|ID4|   A|   10|
+---+----+-----+


会有更好的替代方法来找到它们。

答案 1 :(得分:4)

根据OP的要求,我记下了我在评论中写的答案。

当前问题的目的是过滤掉DataFrame,其中每个特定的ID仅包含一个元素Type A,而没有其他元素。

# Loading the requisite packages
from pyspark.sql.functions import col, collect_set, array_contains, size, first

这个想法是首先用DataFrame aggregate() ID,然后我们将collect_set()中的unique的所有Type元素归类为collect_set()数组。拥有unique元素很重要,因为对于一个特定的ID可能会有两行,而这两行中的TypeA。这就是为什么我们应该使用collect_list()而不是first()的原因,因为后者不会返回唯一元素,而是所有元素。

然后,我们应该使用array_contains()获取组中TypeValue的第一个值。如果A是特定unique唯一的Type ID的情况,则first()将返回A的唯一值A出现一次,如果A重复则出现最高值。

df = df = df.groupby(['ID']).agg(first(col('Type')).alias('Type'),
                                 first(col('Value')).alias('Value'),
                                 collect_set('Type').alias('Type_Arr'))
df.show()
+---+----+-----+---------+
| ID|Type|Value| Type_Arr|
+---+----+-----+---------+
|ID2|   A|   12|      [A]|
|ID3|   A|    3|[A, B, C]|
|ID1|   A|    1|   [A, B]|
|ID4|   A|   10|      [A]|
+---+----+-----+---------+

最后,我们将同时放置2个条件以过滤出所需的数据集。

条件1:使用size检查A数组中Type的存在。

条件2:它检查数组的this。如果大小大于1,则应该有多个Types

df = df.where(array_contains(col('Type_Arr'),'A') & (size(col('Type_Arr'))==1)).drop('Type_Arr')
df.show()
+---+----+-----+
| ID|Type|Value|
+---+----+-----+
|ID2|   A|   12|
|ID4|   A|   10|
+---+----+-----+

答案 2 :(得分:3)

我不太懂Python,这是Scala中的一种可能的解决方案:

df.groupBy("ID").agg(collect_set("Type").as("Types"))
  .select("ID").where((size($"Types")===1).and(array_contains($"Types", "A"))).show()
+---+
| ID|
+---+
|ID2|
|ID4|
+---+

这个想法是根据ID进行分组,然后仅过滤大小为1的Types,其中包含A值。