I want to add a new column to each row of my DataFrame and need to do some calculation over the whole DataFrame for this. Is this possible?
I have a DataFrame that contains a three-dimensional index (to locate a cell in a grid) and a value per row:
+---+-----+----+---+
|  t|    x|   y|  n|
+---+-----+----+---+
|  6|-4008|7321|  5|
|  7|-4007|7317|  1|
|  6|-4008|7323|  2|
|  6|-4008|7324| 17|
|  7|-4007|7326| 16|
+---+-----+----+---+
…
Now I want to add a fifth column (w) that holds the n values of the neighboring cells for each row.
Example:
+---+-----+----+---+---+
|  t|    x|   y|  n|  w|
+---+-----+----+---+---+
|  6|-4008|7318|  5|  1|
|  7|-4007|7317|  1|  5|
+---+-----+----+---+---+
…
For this I have to locate the current cell in the grid (easily done via t, x, y), find the cells surrounding it, and count and sum up their n values. The result of this would be what I want to insert into the w column.
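The neighborhood rule described above can be sketched in plain Scala collections, without Spark, to pin down what the calculation should return. This is only an illustration: the `Cell` case class, `isNeighbor`, and `neighborStats` are hypothetical names, and excluding the cell itself from its own neighborhood is an assumption taken from the example table (where w = 1 for the row with n = 5).

```scala
// Hypothetical row type; field names follow the columns t, x, y, n.
final case class Cell(t: Int, x: Int, y: Int, n: Int)

object NeighborSketch {
  // b is a neighbor of a if every coordinate differs by at most 1
  // and b is not a itself (assumption based on the example table).
  def isNeighbor(a: Cell, b: Cell): Boolean =
    math.abs(a.t - b.t) <= 1 &&
      math.abs(a.x - b.x) <= 1 &&
      math.abs(a.y - b.y) <= 1 &&
      (a.t, a.x, a.y) != (b.t, b.x, b.y)

  // Returns (number of neighbors, sum of their n values).
  def neighborStats(cells: Seq[Cell], cell: Cell): (Int, Int) = {
    val ns = cells.filter(isNeighbor(cell, _)).map(_.n)
    (ns.size, ns.sum)
  }

  def main(args: Array[String]): Unit = {
    // The two rows from the example table.
    val cells = Seq(Cell(6, -4008, 7318, 5), Cell(7, -4007, 7317, 1))
    cells.foreach(c => println(s"$c -> ${neighborStats(cells, c)}"))
  }
}
```

The linear scan per cell is quadratic overall, so this only documents the intended semantics; it is not meant as the way to compute w on a large grid.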
Now how do I perform this calculation for every single row, when I need access to the whole DataFrame for it?
My initial idea was to do something like:
// df is my DataFrame that holds t, x, y and n
df.map(row => Row(row(0), row(1), row(2), row(3), calculateW(df, row)))
But I get an NPE every time I try to access df inside calculateW().
Since I am new to thinking in Spark I came up with the following solution. It works but is terribly slow.
def udfCalcZ = udf(...)
def udfCalcP = udf(...)
val dfWithPandZ = df.as("a")
  .join(
    df.as("b"),
    $"b.t".geq($"a.t" - 1) && $"b.t".leq($"a.t" + 1)
      && $"b.y".geq($"a.y" - 1) && $"b.y".leq($"a.y" + 1)
      && $"b.x".geq($"a.x" - 1) && $"b.x".leq($"a.x" + 1)
  )
  .groupBy($"a.x", $"a.y", $"a.t", $"a.n")
  .agg(count($"b.n").as("wLength"), sum($"b.n").as("wSum"))
  .withColumn("zscore", udfCalcZ($"wLength", $"wSum"))
  .withColumn("pvalue", udfCalcP($"zscore"))
Answer 0 (score: -2)
If you need the whole DataFrame, you can collect it to the driver first, like this:
val df = yourDataFrame
val df1 = df.collect()
df.map(row => Row(row(0), row(1), row(2), row(3), calculateW(df1, row)))
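This adds computation overhead, because collect() pulls every row to the driver, and you need to modify your calculateW() function so that it accepts the collected Array[Row] instead of a DataFrame. Still, it seems suitable for your use case.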
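As a rough sketch of how the collect-based approach could be wired up: a DataFrame cannot be used inside code that runs on the executors, so the rows are collected on the driver and turned into a lookup map first. Everything named here (`CollectSketch`, `withNeighborSum`, the UDF) is hypothetical, the broadcast variable is an addition not mentioned above, the columns are assumed to be integers, and the whole grid is assumed to fit in driver memory.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, udf}

object CollectSketch {
  // Adds a column w holding the sum of n over the (up to 26) neighboring cells.
  def withNeighborSum(spark: SparkSession, df: DataFrame): DataFrame = {
    // Collect on the driver and index the rows by (t, x, y).
    val grid: Map[(Int, Int, Int), Int] = df.collect().map { r =>
      (r.getAs[Int]("t"), r.getAs[Int]("x"), r.getAs[Int]("y")) -> r.getAs[Int]("n")
    }.toMap

    // Ship one read-only copy of the map to each executor.
    val gridB = spark.sparkContext.broadcast(grid)

    // Sum n over all cells whose coordinates differ by at most 1,
    // skipping the cell itself; missing cells contribute nothing.
    val neighborSum = udf { (t: Int, x: Int, y: Int) =>
      val g = gridB.value
      (for {
        dt <- -1 to 1
        dx <- -1 to 1
        dy <- -1 to 1
        if !(dt == 0 && dx == 0 && dy == 0)
        n <- g.get((t + dt, x + dx, y + dy))
      } yield n).sum
    }

    df.withColumn("w", neighborSum(col("t"), col("x"), col("y")))
  }
}
```

Keying the map by (t, x, y) makes each row's lookup a fixed 26 probes rather than a scan over all collected rows; whether that beats the self-join depends on how large the grid is.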