如何在矩阵中找到每列的五个第一个最大索引?

时间:2015-05-23 18:22:45

标签: scala matrix apache-spark

这是我的矩阵,我想使用Breeze在Spark和Scala中提取每列的五个第一个最大索引:

indices

  0         0.23 0.20 0.10 0.92 0.33 0.42
  1         0.10 0.43 0.23 0.15 0.22 0.12
  2         0.20 0.13 0.25 0.85 0.02 0.32
  3         0.43 0.65 0.23 0.45 0.10 0.33
  4         0.31 0.87 0.45 0.63 0.28 0.16
  5         0.12 0.84 0.33 0.45 0.56 0.83
  6         0.40 0.22 0.12 0.87 0.35 0.78
           ...

(注意:索引不在矩阵中,只是为了更好地显示问题)

和预期的输出是:

3 4 4 0 5 5
6 5 5 6 6 6
4 3 2 2 0 0
0 1 1 4 4 3
2 6 3 3 1 2

我试过了:

  for (i <- 0 until I) {
      val T = argmax(matrix(::, i))
      results(::,i) := T
    }

但它只返回第一个最大索引。

有人能帮助我吗?

1 个答案:

答案 0 :(得分:2)

我认为您可以使用Scala可以为您提供的一些函数式编程,Breeze对于在matlab中工作非常有用,但是argmax()只给出了向量中具有更大数字的索引。当然你可以这样工作,然后得到第二个更大,然后第三个...,但是在这里你有我的建议,我认为这也将有助于你的Spark代码,以便并行化和使用更大的矩阵,请阅读评论以获得解释,也可以随意进行更改,以便使用Spark获得最大功能:

package breeze

import breeze.linalg.{DenseMatrix }

/**
 * Created by anquegi on 24/05/15.
 */
object TestMatrix extends App {

  //This is a DenseMatrix from Breeze,
  // I suppose that you have something like this

  val m = DenseMatrix(
    (0.23, 0.20, 0.10, 0.92, 0.33, 0.42),
    (0.10, 0.43, 0.23, 0.15, 0.22, 0.12),
    (0.20, 0.13, 0.25, 0.85, 0.02, 0.32),
    (0.43, 0.65, 0.23, 0.45, 0.10, 0.33),
    (0.31, 0.87, 0.45, 0.63, 0.28, 0.16),
    (0.12, 0.84, 0.33, 0.45, 0.56, 0.83),
    (0.40, 0.22, 0.12, 0.87, 0.35, 0.78))

  // Let's work in a mix functional style and iterator working with columns
  // look at this example

  val a = m(::, 0) // get the firts column
    .toArray // pass to scala array for functional usage, you can use then to List
    .zipWithIndex // now you have and array like [(value0,0),(value1,1) ... (valuen,n)]
    .sortWith((x, y) => x._1 > y._1) // sort by bigger number
    .take(5) // get only 5 first numbers
    .map(x => x._2) // finally get the indexes

  //now we have to loop for each colum
  // prepare the matrix and get the Vector(indexes,Array[Int],Array[Int])

  val listsOfIndexes = for (i <- Range(0, m.cols))
    yield m(::, i).toArray
    .zipWithIndex
    .sortWith((x, y) => x._1 > y._1)
    .take(5)
    .map(x => x._2)

  //finally conver to a DenseMatrix

  val mIndex = DenseMatrix(listsOfIndexes.map(_.toArray): _*).t

  println(mIndex)

}

结果:

[info] Running breeze.TestMatrix 
3  4  4  0  5  5  
6  5  5  6  6  6  
4  3  2  2  0  0  
0  1  1  4  4  3  
2  6  3  3  1  2  
[success] Total time: 5 s, completed 24/05/2015 16:59:43