How do I access the whole DataFrame in a map/agg function?

Time: 2016-07-11 19:23:42

Tags: scala apache-spark spark-dataframe

I want to add a new column to each row of my DataFrame and need to do some calculation over the whole DataFrame for this. Is this possible?

Context

I have a DataFrame that contains a three-dimensional index (to locate a cell in a grid) and a value per row:

+---+-----+----+---+
|  t|    x|   y|  n|
+---+-----+----+---+
|  6|-4008|7321|  5|
|  7|-4007|7317|  1|
|  6|-4008|7323|  2|
|  6|-4008|7324| 17|
|  7|-4007|7326| 16|
+---+-----+----+---+
…

Now I want to add a fifth column (w) that holds the n values of the neighboring cells for each row.
Example:

+---+-----+----+---+---+
|  t|    x|   y|  n|  w|
+---+-----+----+---+---+
|  6|-4008|7318|  5|  1|
|  7|-4007|7317|  1|  5|
+---+-----+----+---+---+
…

For this I have to locate the current cell in the grid (easily done via t, x, y), find the cells surrounding it, and count and sum up their n values. The result is what I want to insert into the w column.

Now how do I perform this calculation for every single row, when I need access to the whole DataFrame for it?


Update 1

My initial idea was to do something like:

// df is my DataFrame that holds t, x, y and n
df.map(row => Row(row(0), row(1), row(2), row(3), calculateW(df, row)))

But I get a NullPointerException (NPE) every time I try to access df inside calculateW().


Update 2, containing "solution"

Since I am new to thinking in Spark, I came up with the following solution. It works, but it is terribly slow.

def udfCalcZ = udf(...)
def udfCalcP = udf(...)

val dfWithPandZ = df.as("a")
  .join(
    df.as("b"),
    $"b.t".geq($"a.t" - 1) && $"b.t".leq($"a.t" + 1)
      && $"b.y".geq($"a.y" - 1) && $"b.y".leq($"a.y" + 1)
      && $"b.x".geq($"a.x" - 1) && $"b.x".leq($"a.x" + 1)
  )
  .groupBy($"a.x", $"a.y", $"a.t", $"a.n").agg(count($"b.n").as("wLength"), sum($"b.n").as("wSum"))
  .withColumn("zscore", udfCalcZ($"wLength", $"wSum"))
  .withColumn("pvalue", udfCalcP($"zscore"))

1 answer:

Answer 0 (score: -2)

If you need the whole DataFrame, you can collect it, like this:

val df = yourDataFrame
val df1 = df.collect() // Array[Row]; every row is pulled into driver memory

df.map(row => Row(row(0), row(1), row(2), row(3), calculateW(df1, row)))

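This adds computation overhead, because collect() brings every row to the driver, and you need to modify your calculateW() function so that it accepts the collected rows instead of the DataFrame. It seems suitable for your use case.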
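
As a rough illustration of what a modified calculateW() might look like once the data has been collected, here is a minimal sketch in plain Scala. It is not the asker's actual function. Cell, NeighborSketch and buildIndex are hypothetical names introduced here, and Cell stands in for a collected row rather than Spark's Row type. The sketch assumes each (t, x, y) coordinate occurs at most once and, like the join in Update 2, counts the cell itself as part of its neighborhood.

```scala
// Hypothetical stand-in for one collected row (t, x, y, n); not Spark's Row type.
final case class Cell(t: Int, x: Int, y: Int, n: Long)

object NeighborSketch {
  // Index the collected cells by coordinate once, so each neighbor lookup
  // is a map access instead of a scan over the whole array.
  // Assumption: coordinates are unique; with duplicates, toMap keeps the last one.
  def buildIndex(cells: Seq[Cell]): Map[(Int, Int, Int), Long] =
    cells.map(c => (c.t, c.x, c.y) -> c.n).toMap

  // Count and sum the n values found in the 3x3x3 block around `cell`.
  // The block includes the cell itself, matching the join condition in Update 2.
  def calculateW(index: Map[(Int, Int, Int), Long], cell: Cell): (Int, Long) = {
    val found = for {
      dt <- -1 to 1
      dx <- -1 to 1
      dy <- -1 to 1
      n  <- index.get((cell.t + dt, cell.x + dx, cell.y + dy)) // skips empty cells
    } yield n
    (found.size, found.sum)
  }
}
```

With the five sample rows from the question, the cell (6, -4008, 7323) has exactly one other occupied cell in its block, (6, -4008, 7324), so calculateW would return a count of 2 and a sum of 19. Building the index once keeps the per-row work constant (27 lookups), but the whole grid still has to fit in memory.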