在Scala中将Matrix的对角线值替换为1

时间:2018-07-10 04:56:19

标签: scala apache-spark apache-spark-mllib

我有一个mllib矩阵,

mat: org.apache.spark.mllib.linalg.Matrix =
0.0  2.0  1.0  2.0
2.0  0.0  2.0  4.0
1.0  2.0  0.0  3.0
2.0  4.0  3.0  0.0

根据我的scala代码,matrix的对角线为0.0。我需要像

一样将此值替换为1.0
mat: org.apache.spark.mllib.linalg.Matrix =
1.0  2.0  1.0  2.0
2.0  1.0  2.0  4.0
1.0  2.0  1.0  3.0
2.0  4.0  3.0  1.0

我该如何实现?请为我提供优化的解决方案。

1 个答案:

答案 0 :(得分:1)

要更改对角线,您需要将矩阵转换为行迭代器,以便可以对其进行迭代,然后使用索引压缩该迭代器,以根据索引替换每一行的元素,这是矩阵的对角元素。下面是带有必需注释的代码。

import org.apache.spark.mllib.linalg.DenseMatrix

//creating initial matrix which needs to be changes
val arr = Array(0.0,2.0,1.0,2.0,2.0,0.0,2.0,4.0,1.0,2.0,0.0,3.0,2.0,4.0,3.0,0.0)
val mat = new DenseMatrix(4,4,arr)

//output
//0.0  2.0  1.0  2.0
//2.0  0.0  2.0  4.0
//1.0  2.0  0.0  3.0
//2.0  4.0  3.0  0.0

//make the iterator and change the element at the index of each row
val changedArr = mat.rowIter.zipWithIndex.flatMap(x => {
  val ar = x._1.toArray
  ar(x._2) = 1.0
  ar
}).toArray

//create new matrix from it
val changedMat = new DenseMatrix(mat.numRows, mat.numCols, changedArr)

//output
//1.0  2.0  1.0  2.0
//2.0  1.0  2.0  4.0
//1.0  2.0  1.0  3.0
//2.0  4.0  3.0  1.0