我在计算标准偏差(stddev)时得到NaN。这是一个非常简单的用例,如下所述:
val df = Seq(("1",19603176695L),("2", 26438904194L),("3",29640527990L),("4",21034972928L),("5", 23975L)).toDF("v","data")
我将stddev定义为UDF:
def stddev(col: Column) = {
sqrt(mean(col*col) - mean(col)*mean(col))
}
当我调用UDF时,我得到NaN
,如下所示:
df.agg(stddev(col("data")).as("stddev")).show()
它产生以下内容:
+------+
|stddev|
+------+
| NaN|
+------+
我做错了什么?
答案 0 :(得分:3)
根据您的数据,mean(col*col)
和mean(col)*mean(col)
都会大于Long
的最大值。您可以先尝试将输入列投射到double
:
df.agg(stddev(col("data").cast("double")).as("stddev"))
但总的来说,在非常大的数字上它不会特别稳定。