如何计算Apache Spark中RowMatrix的反转?

时间:2015-04-30 13:48:34

标签: scala apache-spark linear-algebra distributed-computing

我有一个以RowMatrix形式的X分布式矩阵。我使用的是Spark 1.3.0。我需要能够计算X逆。

3 个答案:

答案 0 :(得分:7)

import org.apache.spark.mllib.linalg.{Vectors,Vector,Matrix,SingularValueDecomposition,DenseMatrix,DenseVector}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

def computeInverse(X: RowMatrix): DenseMatrix = {
  val nCoef = X.numCols.toInt
  val svd = X.computeSVD(nCoef, computeU = true)
  if (svd.s.size < nCoef) {
    sys.error(s"RowMatrix.computeInverse called on singular matrix.")
  }

  // Create the inv diagonal matrix from S 
  val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x,-1))))

  // U cannot be a RowMatrix
  val U = new DenseMatrix(svd.U.numRows().toInt,svd.U.numCols().toInt,svd.U.rows.collect.flatMap(x => x.toArray))

  // If you could make V distributed, then this may be better. However its alreadly local...so maybe this is fine.
  val V = svd.V
  // inv(X) = V*inv(S)*transpose(U)  --- the U is already transposed.
  (V.multiply(invS)).multiply(U)
  }

答案 1 :(得分:3)

使用带有选项

的此功能时遇到问题
conf.set("spark.sql.shuffle.partitions", "12")

RowMatrix中的行被洗牌了。

这是一个适合我的更新

import org.apache.spark.mllib.linalg.{DenseMatrix,DenseVector}
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix

def computeInverse(X: IndexedRowMatrix)
: DenseMatrix = 
{
  val nCoef = X.numCols.toInt
  val svd = X.computeSVD(nCoef, computeU = true)
  if (svd.s.size < nCoef) {
    sys.error(s"IndexedRowMatrix.computeInverse called on singular matrix.")
  }

  // Create the inv diagonal matrix from S 
  val invS = DenseMatrix.diag(new DenseVector(svd.s.toArray.map(x => math.pow(x, -1))))

  // U cannot be a RowMatrix
  val U = svd.U.toBlockMatrix().toLocalMatrix().multiply(DenseMatrix.eye(svd.U.numRows().toInt)).transpose

  val V = svd.V
  (V.multiply(invS)).multiply(U)
}

答案 2 :(得分:0)

X.computeSVD返回的矩阵U的维度为mxk,其中 m 是原始(分布式)RowMatrix X的行数。人们希望 m 为大(可能大于 k ),因此如果我们希望我们的代码扩展到非常大的 m 值,则不建议在驱动程序中收集它。

我想说下面的两个解决方案都存在这个漏洞。 @ Alexander Kharlamov给出的答案调用val U = svd.U.toBlockMatrix().toLocalMatrix()来收集驱动程序中的矩阵。 @ Climbs_lika_Spyder给出的答案也是如此(顺便说一下你的昵称岩石!!),它会调用svd.U.rows.collect.flatMap(x => x.toArray)。我宁愿建议依赖分布式矩阵乘法,例如发布的here的Scala代码。