使用R< dplyr,我会像这样计算组间的差异:
df %>% group_by(group) %>% summarise(total = sum(value)) %>% summarise(variance_between_groups = var(total))
尝试使用Sparks DataFrame API执行相同的操作:
df.groupBy(group).agg(sum(value).alias("total")).agg(var_samp(total).alias("variance_between_groups"))
我在第二个agg
中收到错误,说它无法找到total
。我显然是在误解某些东西,所以任何帮助都会受到赞赏。
答案 0 :(得分:1)
var_samp()
采用String类型的列名,因此您需要提供如下字符串:
import org.apache.spark.sql.functions._
val df = Seq(
("a", 1.0),
("a", 2.5),
("a", 1.5),
("b", 2.0),
("b", 1.6)
).toDF("group", "value")
df.groupBy("group").
agg(sum("value").alias("total")).
agg(var_samp("total").alias("variance_between_groups")).
show
// +-----------------------+
// |variance_between_groups|
// +-----------------------+
// | 0.9799999999999999|
// +-----------------------+
它也可以采用列(列类型),例如var_samp($"total")
。有关更多详细信息,请参阅Spark的API doc。