Calculating the confidence interval of the mean over all rows of a DataFrame in Spark / Scala

Time: 2017-07-05 10:10:41

Tags: scala apache-spark apache-spark-sql

I need to compute the confidence interval of the mean of the value3 column of my DataFrame (the upper bound, the lower bound, and the interval itself), and I need to attach the result to every row of the DataFrame. This is my DataFrame:

+--------+---------+------+
|  value1| value2  |value3|
+--------+---------+------+
|   a    |  2      |   3  |
|   b    |  5      |   4  |
|   b    |  5      |   4  |
|   c    |  3      |   4  |
+--------+---------+------+

So my output should look like this (x is the computed result):

    +--------+---------+------+-------+--------+------+
    |  value1| value2  |value3|max_int|min_int | int  |
    +--------+---------+------+-------+--------+------+
    |   a    |  2      |   3  |   x   |   x    |  x   |
    |   b    |  5      |   4  |   x   |   x    |  x   |
    |   b    |  5      |   4  |   x   |   x    |  x   |
    |   c    |  3      |   4  |   x   |   x    |  x   |
    +--------+---------+------+-------+--------+------+

Since I couldn't find a built-in function for this, I found the following code to calculate it:

    import org.apache.commons.math3.distribution.TDistribution
    import org.apache.commons.math3.exception.MathIllegalArgumentException
    import org.apache.commons.math3.stat.descriptive.SummaryStatistics

    object ConfidenceIntervalApp {

      def main(args: Array[String]): Unit = {

        // my DataFrame is named df

        // Feed the value3 column into SummaryStatistics
        // (assumes value3 is an integer column, as in the sample data)
        val stats = new SummaryStatistics()
        df.select("value3").collect().foreach(row => stats.addValue(row.getInt(0)))

        // Calculate the 95% confidence interval half-width
        val ci: Double = calcMeanCI(stats, 0.95)
        println(s"Mean: ${stats.getMean}")
        val lower: Double = stats.getMean - ci
        val upper: Double = stats.getMean + ci
      }

      // Half-width of the confidence interval of the mean at the given level
      def calcMeanCI(stats: SummaryStatistics, level: Double): Double =
        try {
          // t distribution with N-1 degrees of freedom
          val tDist: TDistribution = new TDistribution(stats.getN - 1)
          // Critical value for a two-sided interval
          val critVal: Double =
            tDist.inverseCumulativeProbability(1.0 - (1 - level) / 2)
          // Confidence interval half-width
          critVal * stats.getStandardDeviation / Math.sqrt(stats.getN)
        } catch {
          case _: MathIllegalArgumentException => Double.NaN
        }
    }
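
Collecting the column to the driver works for the small sample above; on larger data the same inputs (mean, sample standard deviation, count) can be computed distributively with Spark SQL aggregates. A minimal sketch, assuming the same df:

    import org.apache.commons.math3.distribution.TDistribution
    import org.apache.spark.sql.functions.{avg, count, stddev_samp}

    // Aggregate mean, sample standard deviation and row count in one pass
    val row = df.agg(
      avg("value3").as("mean"),
      stddev_samp("value3").as("sd"),
      count("value3").as("n")
    ).head()
    val (mean, sd, n) = (row.getDouble(0), row.getDouble(1), row.getLong(2))

    // Same t-based half-width as calcMeanCI above
    val tDist = new TDistribution(n - 1.0)
    val critVal = tDist.inverseCumulativeProbability(1.0 - (1 - 0.95) / 2)
    val ci = critVal * sd / math.sqrt(n.toDouble)
    val (lower, upper) = (mean - ci, mean + ci)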

Can you help, or at least tell me how to apply this to the column? Thanks in advance.

1 Answer:

Answer 0 (score: 1)

You can do something like this:

    import org.apache.spark.sql.functions.lit

    val cntInterval = df.select("value3").rdd.countApprox(timeout = 1000L, confidence = 0.95)
    val (lowCnt, highCnt) = (cntInterval.getFinalValue().low, cntInterval.getFinalValue().high)

    df.withColumn("max_int", lit(highCnt))
      .withColumn("min_int", lit(lowCnt))
      .withColumn("int", lit(cntInterval.getFinalValue().toString()))
      .show(false)
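
Note that countApprox bounds the approximate row count, not the mean of value3. If the goal is the mean-based interval from the question, the same withColumn / lit pattern applies, since the bounds are plain scalars that lit broadcasts to every row. A minimal sketch, assuming lower, upper and ci computed as in the question's code:

    import org.apache.spark.sql.functions.lit

    // lower, upper and ci are scalars, so every row gets the same value
    df.withColumn("max_int", lit(upper))
      .withColumn("min_int", lit(lower))
      .withColumn("int", lit(ci))
      .show(false)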

I got help from "In spark, how to estimate the number of elements in a dataframe quickly".