如何从pyspark的每隔一行中减去spark数据帧中的每一行?

时间:2019-04-03 05:58:55

标签: python apache-spark pyspark apache-spark-sql

我有一个Spark数据框,其中有3列,分别指示原子的位置,即位置X,Y和Z。现在要查找需要应用距离公式的每2个原子之间的距离。 距离公式为d = sqrt((x2-x1)^ 2 +(y2-y1)^ 2 +(z2-z1)^ 2)

因此要应用上述公式,我需要从x的每隔一行中减去x中的每一行,从y中的每隔一行中减去y的每一行,以此类推。然后对每两个原子应用上述公式。

我试图制作一个用户定义的函数(udf),但是我无法将整个spark数据帧传递给它,我只能分别传递每一列,而不是整个数据帧。由于这个原因,我无法遍历整个数据框,因此我必须在每一列上申请循环。下面的代码显示了我仅对Position_X所做的迭代。

@udf
def Distance(Position_X,Position_Y, Position_Z):
    try:
       for x,z in enumerate(Position_X) :
           firstAtom = z
           for y, a in enumerate(Position_X):
                if (x!=y):
                    diff = firstAtom - a
           return diff
    except:
        return None

newDF1 = atomsDF.withColumn("Distance", Distance(*atomsDF.columns))

My atomDF spark dataframe look like this, each row shows the x,y,z coordinates of one atom in space. Right now we are taking only 10 atoms.

Position_X|Position_Y|Position_Z|
+----------+----------+----------+
|    27.545|     6.743|    12.111|
|    27.708|     7.543|    13.332|
|    27.640|     9.039|    12.970|
|    26.991|     9.793|    13.693|
|    29.016|     7.166|    14.106|
|    29.286|     8.104|    15.273|
|    28.977|     5.725|    14.603|
|    28.267|     9.456|    11.844|
|    28.290|    10.849|    11.372|
|    26.869|    11.393|    11.161|
+----------+----------+----------+

如何在pyspark i-e中解决上述问题。如何从每隔一行中减去每一行?如何将整个spark数据框传递给udf而不是其列?以及如何避免使用太多for循环?

每两个原子(行)的预期输出将是使用上述距离公式计算的两行之间的距离。我不需要保留该距离,因为我将使用它来表示势能。或者,如果我可以将其保留在单独的数据框中,则不介意。

1 个答案:

答案 0 :(得分:1)

我想比较2到2进行交叉连接所需的原子(线),不建议这样做。

您可以使用函数monotonically_increasing_id为每一行生成一个ID。

from pyspark.sql import functions as F
df = df.withColumn("id", F.monotonically_increasing_id())

然后,您将数据框与自身交叉连接,并使用“ id_1> id_2”

行进行过滤
df_1 = df.select(*(F.col(col).alias("{}_1".format(col)) for col in df.columns))
df_2 = df.select(*(F.col(col).alias("{}_2".format(col)) for col in df.columns))
df_3 = df_1.crossJoin(df_2).where("id_1 > id_2")

df_3包含您需要的45行。您只需要应用公式:

df_4 = df_3.withColumn(
    "distance",
    F.sqrt(
        F.pow(F.col("Position_X_1") - F.col("Position_X_2"), F.lit(2))
        + F.pow(F.col("Position_Y_1") - F.col("Position_Y_2"), F.lit(2))
        + F.pow(F.col("Position_Z_1") - F.col("Position_Z_2"), F.lit(2))
    )
)


df_4.orderBy('id_2', 'id_1').show()
+------------+------------+------------+----------+------------+------------+------------+----+------------------+
|Position_X_1|Position_Y_1|Position_Z_1|      id_1|Position_X_2|Position_Y_2|Position_Z_2|id_2|          distance|
+------------+------------+------------+----------+------------+------------+------------+----+------------------+
|      27.708|       7.543|      13.332|         1|      27.545|       6.743|      12.111|   0|1.4688124454810418|
|       27.64|       9.039|       12.97|         2|      27.545|       6.743|      12.111|   0| 2.453267616873462|
|      26.991|       9.793|      13.693|         3|      27.545|       6.743|      12.111|   0| 3.480249991020759|
|      29.016|       7.166|      14.106|         4|      27.545|       6.743|      12.111|   0|2.5145168522004355|
|      29.286|       8.104|      15.273|8589934592|      27.545|       6.743|      12.111|   0|3.8576736513085175|
|      28.977|       5.725|      14.603|8589934593|      27.545|       6.743|      12.111|   0| 3.049100195139542|
|      28.267|       9.456|      11.844|8589934594|      27.545|       6.743|      12.111|   0|2.8200960976534106|
|       28.29|      10.849|      11.372|8589934595|      27.545|       6.743|      12.111|   0| 4.237969089080287|
|      26.869|      11.393|      11.161|8589934596|      27.545|       6.743|      12.111|   0| 4.793952023122468|
|       27.64|       9.039|       12.97|         2|      27.708|       7.543|      13.332|   1|1.5406764747993003|
|      26.991|       9.793|      13.693|         3|      27.708|       7.543|      13.332|   1|2.3889139791964036|
|      29.016|       7.166|      14.106|         4|      27.708|       7.543|      13.332|   1|1.5659083625806454|
|      29.286|       8.104|      15.273|8589934592|      27.708|       7.543|      13.332|   1|2.5636470115833037|
|      28.977|       5.725|      14.603|8589934593|      27.708|       7.543|      13.332|   1|2.5555676473143896|
|      28.267|       9.456|      11.844|8589934594|      27.708|       7.543|      13.332|   1|  2.48720606303539|
|       28.29|      10.849|      11.372|8589934595|      27.708|       7.543|      13.332|   1|  3.88715319996524|
|      26.869|      11.393|      11.161|8589934596|      27.708|       7.543|      13.332|   1| 4.498851186691999|
|      26.991|       9.793|      13.693|         3|       27.64|       9.039|       12.97|   2|1.2298154333069653|
|      29.016|       7.166|      14.106|         4|       27.64|       9.039|       12.97|   2|2.5868902180030737|
|      29.286|       8.104|      15.273|8589934592|       27.64|       9.039|       12.97|   2|2.9811658793163454|
+------------+------------+------------+----------+------------+------------+------------+----+------------------+
only showing top 20 rows

它仅处理少量数据,但是crossJoin会占用大量数据,从而破坏性能。