I want to compare the values in one column against another column that holds a set of reference values.
I tried the following code:
from pyspark.sql.functions import udf, size
from pyspark.sql.types import *
df1 = sc.parallelize([([1], [1, 2, 3]), ([2], [4, 5, 6, 7])]).toDF(["value", "Reference_value"])
intersect = lambda type: (udf(
    lambda x, y: (
        list(set(x) & set(y)) if x is not None and y is not None else None),
    ArrayType(type)))
integer_intersect = intersect(IntegerType())
# df1.select(
# integer_intersect("value", "Reference_value"),
# size(integer_intersect("value", "Reference_value"))).show()
df1 = df1.where(size(integer_intersect("value", "Reference_value")) > 0)
df1.show()
The code above works when the dataframe is created as shown below, because the value and Reference_value columns are then arrays (ArrayType of LongType). But when I read the dataframe from a CSV file, both columns come in as plain strings and I cannot cast them to array type (see the sketch after the table below). Here df1 is read from a CSV and looks like this:
df1 =
category   value   Reference_value
count      1       1
n_timer    n20     n40,n20
frames     54      56
timer      n8      n3,n6,n7
pdf        FALSE   TRUE
zip        FALSE   FALSE
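For reference, here is a minimal sketch of what that CSV read might look like and why the columns are not arrays (the file name, the header option, and the spark session are assumptions, not from the original post):

df1 = spark.read.csv("input.csv", header=True)  # hypothetical path and options
df1.printSchema()
# root
#  |-- category: string (nullable = true)
#  |-- value: string (nullable = true)
#  |-- Reference_value: string (nullable = true)
# Every column is read as a plain string, so the array-based UDF above cannot
# be applied until the comma-separated strings are split into arrays.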
I want to compare the "value" column with the "Reference_value" column and derive two new dataframes: one that keeps the rows whose value appears in the reference-value set (df2), and one that keeps the rows whose value does not (df3).
Output df2 =
category   value   Reference_value
count      1       1
n_timer    n20     n40,n20
zip        FALSE   FALSE
Output df3 =
category   value   Reference_value
frames     54      56
timer      n8      n3,n6,n7
pdf        FALSE   TRUE
Is there a simpler way, something like array_contains? I tried array_contains, but it did not work:
from pyspark.sql.functions import array_contains
df1.where(array_contains("Reference_value", df1["value"]))
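array_contains fails here mainly because Reference_value is still a plain string rather than an array. Once the string is split into an array, the membership test can be done without a UDF via a SQL expression; a minimal sketch (Spark 2.x assumed, not from the original post):

from pyspark.sql.functions import expr, split

# Split the comma-separated reference string into an array of strings,
# then test per-row whether the value column is a member of that array.
df_arr = df1.withColumn("Reference_value", split("Reference_value", r",\s*"))
df2 = df_arr.where(expr("array_contains(Reference_value, value)"))   # value found in the reference set
df3 = df_arr.where(~expr("array_contains(Reference_value, value)"))  # value not found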
Answer 0 (score: -2)
# One can copy-paste the code below to reproduce the input and outputs
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import Row
from pyspark.sql.functions import udf, size
from pyspark.sql.types import *
from pyspark.sql.functions import split
sc = SparkContext.getOrCreate()
sqlContext = SQLContext.getOrCreate(sc)
df1 = sc.parallelize([("count", "1", "1"), ("n_timer", "n20", "n40,n20"), ("frames", "54", "56"),
                      ("timer", "n8", "n3,n6,n7"), ("pdf", "FALSE", "TRUE"),
                      ("zip", "FALSE", "FALSE")]).toDF(["category", "value", "Reference_value"])
df1.show()
# Split the comma-separated strings into arrays of strings
# (split already returns array<string>, so no extra cast is needed)
df1 = df1.withColumn("Reference_value", split("Reference_value", r",\s*"))
df1 = df1.withColumn("value", split("value", r",\s*"))
intersect = lambda type: (udf(
    lambda x, y: (
        list(set(x) & set(y)) if x is not None and y is not None else None),
    ArrayType(type)))
string_intersect = intersect(StringType())
df2 = df1.where(size(string_intersect("value", "Reference_value")) > 0)   # value overlaps the reference set
df3 = df1.where(size(string_intersect("value", "Reference_value")) <= 0)  # no overlap
df2.show()
df3.show()
input df1=
+--------+-----+---------------+
|category|value|Reference_value|
+--------+-----+---------------+
| count| 1| 1|
| n_timer| n20| n40,n20|
| frames| 54| 56|
| timer| n8| n3,n6,n7|
| pdf|FALSE| TRUE|
| zip|FALSE| FALSE|
+--------+-----+---------------+
df2=
+--------+-------+---------------+
|category| value|Reference_value|
+--------+-------+---------------+
| count| [1]| [1]|
| n_timer| [n20]| [n40, n20]|
| zip|[FALSE]| [FALSE]|
+--------+-------+---------------+
df3=
+--------+-------+---------------+
|category| value|Reference_value|
+--------+-------+---------------+
| frames| [54]| [56]|
| timer| [n8]| [n3, n6, n7]|
| pdf|[FALSE]| [TRUE]|
+--------+-------+---------------+
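On Spark 2.4 or later, the intersect UDF in the answer above could also be replaced by the built-in arrays_overlap function, which returns true when two arrays share at least one element; a sketch, assuming value and Reference_value have already been split into arrays as in the answer:

from pyspark.sql.functions import arrays_overlap

# arrays_overlap("value", "Reference_value") is true exactly when the
# size-of-intersection > 0 test performed by the UDF would be true.
overlap = arrays_overlap("value", "Reference_value")
df2 = df1.where(overlap)
df3 = df1.where(~overlap)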