我有两个数据框,我需要使用具有两个连接谓词的非等距连接(即不等式连接)将它们连接在一起。
一个数据帧是直方图DataFrame[bin: bigint, lower_bound: double, upper_bound: double]
另一个数据框是观测值DataFrame[id: bigint, observation: double]
我需要确定每个观测值属于我的直方图的bin,如下所示:
observations_df.join(histogram_df,
(
(observations_df.observation >= histogram_df.lower_bound) &
(observations_df.observation < histogram_df.upper_bound)
)
)
基本上这很慢,我正在寻找有关如何使其更快运行的建议。
下面是一些示例代码来演示问题。 observations_df
包含100000行,当histogram_df
中的行数变得足够大(假设number_of_bins = 500000
)时,它将变得非常非常慢,我确定它是因为我正在执行非平等加入。如果您运行此代码,然后以number_of_rows
的值开始计算,则从低位开始,然后增加直到性能下降很明显为止
from pyspark.sql.functions import lit, col, lead
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import rand
from pyspark.sql import Window
spark = SparkSession \
.builder \
.getOrCreate()
number_of_bins = 500000
bin_width = 1.0 / number_of_bins
window = Window.orderBy('bin')
histogram_df = spark.range(0, number_of_bins)\
.withColumnRenamed('id', 'bin')\
.withColumn('lower_bound', 0 + lit(bin_width) * col('bin'))\
.select('bin', 'lower_bound', lead('lower_bound', 1, 1.0).over(window).alias('upper_bound'))
observations_df = spark.range(0, 100000).withColumn('observation', rand())
observations_df.join(histogram_df,
(
(observations_df.observation >= histogram_df.lower_bound) &
(observations_df.observation < histogram_df.upper_bound)
)
).groupBy('bin').count().head(15)
答案 0 :(得分:0)
不建议将不相等的联接用于Spark联接。通常,我会为这种操作生成一个新列作为连接键。 但是,对于您的情况,您无需合并即可确定每个观测值属于哪个直方图区间,因为可以预先计算每个区间的上下限,并且可以使用观测值来计算该区间。
您可以做的是编写一个UDF,该UDF为您找到垃圾箱并将该垃圾箱作为新列返回。 您可以参考pyspark: passing multiple dataframe fields to udf