计算Spark中的点态互信息

时间:2015-04-14 06:11:28

标签: apache-spark apache-spark-mllib

我试图计算pointwise mutual information(PMI)。

enter image description here

我有两个分别为p(x,y)和p(x)定义的RDD:

pii: RDD[((String, String), Double)]
 pi: RDD[(String, Double)]

我在编写来自RDD piipi的PMI时编写的任何代码都不是很好。我的方法是首先压平RDD pii并加入pi两次,同时按摩元组元素。

val pmi = pii.map(x => (x._1._1, (x._1._2, x._1, x._2)))
             .join(pi).values
             .map(x => (x._1._1, (x._1._2, x._1._3, x._2)))
             .join(pi).values
             .map(x => (x._1._1, computePMI(x._1._2, x._1._3, x._2)))
// pmi: org.apache.spark.rdd.RDD[((String, String), Double)]
...
def computePMI(pab: Double, pa: Double, pb: Double) = {
  // handle boundary conditions, etc
  log(pab) - log(pa) - log(pb)
}
显然,这很糟糕。有没有更好的(惯用)方式来做到这一点? 注意:我可以通过在pipii中存储log-prob来优化日志,但选择以这种方式编写以保持问题清晰。

1 个答案:

答案 0 :(得分:4)

使用broadcast将是一种解决方案。

val bcPi = pi.context.broadcast(pi.collectAsMap())
val pmi = pii.map {
  case ((x, y), pxy) =>
    (x, y) -> computePMI(pxy, bcPi.value.get(x).get, bcPi.value.get(y).get)
}

假设:pi x中有ypii。{/ p>