How do I compute X * diag(Y) in Scala Breeze?

Time: 2012-09-29 08:18:32

Tags: scala linear-algebra scalala

How can I compute X * diag(Y) in Scala Breeze? X may be a CSCMatrix and Y may be a DenseVector.

In MATLAB syntax, this would be:

X * spdiags(Y, 0, N, N)

or:

X .* repmat(Y', K, 1)

In SciPy syntax, this would be a "broadcast multiply":

Y * X
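
For reference, the expressions above all describe the same column scaling. Written out element by element, assuming X is K x N and Y has length N (dimensions inferred from the MATLAB calls rather than stated explicitly):

```latex
\[
\bigl(X \,\operatorname{diag}(Y)\bigr)_{ij}
  = \sum_{k=1}^{N} X_{ik}\,\operatorname{diag}(Y)_{kj}
  = X_{ij}\,Y_j ,
\qquad 1 \le i \le K,\quad 1 \le j \le N .
\]
```

In other words, column j of X is multiplied by the scalar Y_j, because diag(Y) is zero everywhere except at k = j.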

So how can I do X * diag(Y) in Scala Breeze?

1 answer:

Answer 0: (score: 2)

I ended up writing my own sparse diagonal method and dense/sparse multiplication.

Use them like this:

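import breeze.linalg._

val N = 100000
val K = 100
val A = DenseMatrix.rand(N, K)       // dense N x K matrix
val b = DenseVector.rand(N)          // length-N vector holding the diagonal entries
val c = MatrixHelper.spdiag(b)       // sparse N x N matrix equal to diag(b)
val d = MatrixHelper.mul(A.t, c)     // (K x N) * (N x N) gives a dense K x N result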

Here are the implementations of spdiag and mul:

// Copyright Hugh Perkins 2012
// You can use this under the terms of the Apache Public License 2.0
// http://www.apache.org/licenses/LICENSE-2.0

package root

import breeze.linalg._

object MatrixHelper {
   // it's only efficient to put the sparse matrix on the right hand side, since 
   // it is a column-sparse matrix
   def mul( A: DenseMatrix[Double], B: CSCMatrix[Double] ) : DenseMatrix[Double] = {
      val resultRows = A.rows
      val resultCols = B.cols
      var row = 0
      val result = DenseMatrix.zeros[Double](resultRows, resultCols )
      while( row < resultRows ) {
         var col = 0
         while( col < resultCols ) {
            val rightRowStartIndex = B.colPtrs(col)
            val rightRowEndIndex = B.colPtrs(col + 1) - 1
            val numRightRows = rightRowEndIndex - rightRowStartIndex + 1
            var ri = 0
            var sum = 0.0
            while( ri < numRightRows ) {
               val inner = B.rowIndices(rightRowStartIndex + ri)
               val rightValue = B.data(rightRowStartIndex + ri)
               sum += A(row,inner) * rightValue
               ri += 1
            }
            result(row,col) = sum
            col += 1
         }
         row += 1
      }
      result
   }   

   def spdiag( a: Tensor[Int,Double] ) : CSCMatrix[Double] = {
      val size = a.size
      val result = CSCMatrix.zeros[Double](size,size)
      result.reserve(a.size)
      var i = 0
      while( i < size ) {
         result.rowIndices(i) = i
         result.colPtrs(i) = i
         result.data(i) = a(i) // store the i-th diagonal value, not the index
         //result(i,i) = a(i)
         i += 1
      }
      //result.activeSize = size
      result.colPtrs(i) = i
      result
   }
}
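
When X itself is the sparse (CSC) operand, as in the question, the general multiply can be avoided: X * diag(Y) only scales the stored values of column j by Y(j), and the sparsity pattern stays the same. The following is a minimal sketch of that idea in plain Scala, without Breeze. `CscColumnScale` and `scaleColumns` are hypothetical names introduced for illustration; the `colPtrs` and `data` arrays are assumed to follow the same column-compressed layout that the `mul` method above reads.

```scala
// Hypothetical helper, not part of Breeze.
// Computes the stored values of X * diag(y) for a matrix X held in CSC form.
object CscColumnScale {
  // colPtrs(j) until colPtrs(j + 1) is the range of stored entries of column j.
  // data holds the stored values in that order; row indices are not needed
  // here because scaling a column does not move any entry.
  def scaleColumns(colPtrs: Array[Int],
                   data: Array[Double],
                   y: Array[Double]): Array[Double] = {
    require(colPtrs.length == y.length + 1,
      "expected one column pointer per column, plus a final end pointer")
    val out = data.clone() // leave the input values untouched
    var col = 0
    while (col < y.length) {
      var k = colPtrs(col)
      val end = colPtrs(col + 1)
      while (k < end) {
        out(k) *= y(col) // every stored entry in column col is scaled by y(col)
        k += 1
      }
      col += 1
    }
    out
  }
}
```

For example, the 2 x 3 matrix with rows (1, 0, 2) and (0, 3, 0) has colPtrs (0, 1, 2, 3) and data (1.0, 3.0, 2.0). Calling `CscColumnScale.scaleColumns(Array(0, 1, 2, 3), Array(1.0, 3.0, 2.0), Array(10.0, 100.0, 1000.0))` returns the values 10.0, 300.0, 2000.0; the row indices and column pointers of the result are the same as those of X. The cost is proportional to the number of stored entries, rather than to K * N.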