pyspark: compare a column's values against another column containing a range of reference values

Time: 2017-09-03 10:39:45

Tags: python apache-spark pyspark

I want to compare the values in one column against another column that holds a range of reference values.

I tried the following code:

from pyspark.sql.functions import udf, size
from pyspark.sql.types import *

# assumes an existing SparkContext `sc`, e.g. in the pyspark shell
df1 = sc.parallelize([([1], [1, 2, 3]), ([2], [4, 5, 6, 7])]).toDF(["value", "Reference_value"])

# UDF factory: builds a UDF that intersects two array columns of the given element type
intersect = lambda type: (udf(
    lambda x, y: (
        list(set(x) & set(y)) if x is not None and y is not None else None),
    ArrayType(type)))
integer_intersect = intersect(IntegerType())

# df1.select(
#     integer_intersect("value", "Reference_value"),
#     size(integer_intersect("value", "Reference_value"))).show()

# keep only the rows where the intersection is non-empty
df1 = df1.where(size(integer_intersect("value", "Reference_value")) > 0)
df1.show()

The above code works when the dataframe is created as shown, because the value and Reference_value columns are then ArrayType (of LongType). But when I read the dataframe from a CSV, the columns come in as plain strings and I could not cast them to array type. Here df1 is read from a CSV and looks as follows:

df1 =

category    value    Reference_value

count       1        1
n_timer     n20      n40,n20
frames      54       56
timer       n8       n3,n6,n7
pdf         FALSE    TRUE
zip         FALSE    FALSE

I want to compare the "value" column against the "Reference_value" column and derive two new dataframes: one keeping the rows whose value appears in the reference-value set, and one keeping the rows whose value does not.

Expected output df2 =

category    value    Reference_value

count       1        1
n_timer     n20      n40,n20
zip         FALSE    FALSE

Expected output df3 =

category    value    Reference_value

frames      54       56
timer       n8       n3,n6,n7
pdf         FALSE    TRUE

Is there a simpler way, something like array_contains? I tried array_contains, but it did not work:

from pyspark.sql.functions import array_contains
# fails: Reference_value read from CSV is a plain string, not an array
df1.where(array_contains("Reference_value", df1["value"]))
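If Reference_value is first split into an array while value stays a single string, array_contains can be called through a SQL expression, which accepts a column as its second argument. A minimal sketch, assuming the same df1 as above; not a tested answer:

from pyspark.sql.functions import expr, split
df1 = df1.withColumn("Reference_value", split("Reference_value", ",\s*"))
# in Spark SQL, array_contains accepts a column as the value to look up
df2 = df1.where(expr("array_contains(Reference_value, value)"))
df3 = df1.where(~expr("array_contains(Reference_value, value)"))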

1 Answer:

Answer 0 (score: -2)

# You can copy-paste the code below to reproduce the inputs and outputs

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import udf, size, split
from pyspark.sql.types import *

sc = SparkContext.getOrCreate()
sqlContext = SQLContext.getOrCreate(sc)

df1 = sc.parallelize([("count", "1", "1"), ("n_timer", "n20", "n40,n20"),
                      ("frames", "54", "56"), ("timer", "n8", "n3,n6,n7"),
                      ("pdf", "FALSE", "TRUE"), ("zip", "FALSE", "FALSE")]
                     ).toDF(["category", "value", "Reference_value"])
df1.show()

# split the comma-separated strings into array<string> columns
df1 = df1.withColumn("Reference_value", split("Reference_value", ",\s*"))
df1 = df1.withColumn("value", split("value", ",\s*"))

# UDF factory: builds a UDF that intersects two array columns of the given element type
intersect = lambda type: (udf(
    lambda x, y: (
        list(set(x) & set(y)) if x is not None and y is not None else None),
    ArrayType(type)))
string_intersect = intersect(StringType())

# df2: rows whose value appears in the reference set; df3: the remaining rows
df2 = df1.where(size(string_intersect("value", "Reference_value")) > 0)
df3 = df1.where(size(string_intersect("value", "Reference_value")) <= 0)
df2.show()
df3.show()
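For reference: Spark 2.4 (released after this answer) added a built-in array_intersect, which removes the need for the Python UDF. A sketch assuming the columns have already been split into arrays as above:

from pyspark.sql.functions import array_intersect, size
# same filters without a UDF; requires Spark 2.4+
df2 = df1.where(size(array_intersect("value", "Reference_value")) > 0)
df3 = df1.where(size(array_intersect("value", "Reference_value")) <= 0)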

input df1 =
+--------+-----+---------------+
|category|value|Reference_value|
+--------+-----+---------------+
|   count|    1|              1|
| n_timer|  n20|        n40,n20|
|  frames|   54|             56|
|   timer|   n8|       n3,n6,n7|
|     pdf|FALSE|           TRUE|
|     zip|FALSE|          FALSE|
+--------+-----+---------------+

df2=
+--------+-------+---------------+
|category|  value|Reference_value|
+--------+-------+---------------+
|   count|    [1]|            [1]|
| n_timer|  [n20]|     [n40, n20]|
|     zip|[FALSE]|        [FALSE]|
+--------+-------+---------------+

df3=
+--------+-------+---------------+
|category|  value|Reference_value|
+--------+-------+---------------+
|  frames|   [54]|           [56]|
|   timer|   [n8]|   [n3, n6, n7]|
|     pdf|[FALSE]|         [TRUE]|
+--------+-------+---------------+