提取对角元素(如matlab中的spdiags)

时间:2014-08-22 10:14:44

标签: c++ c matlab sparse-matrix

我想尽可能在​​C中实现spdiags功能。

(我更喜欢C到C ++而我现在不想使用C ++算法)

作为输入矩阵:

inMx =
     1     0     0
     4     5     6
     0     7     9

你应该获得(在Matlab中使用spdiags运行它):

ouMx =
     4     1     0
     7     5     0
     0     9     6

(但有一件事我能理解,即使在文档中它说如果你在主对角线下方也会在顶部插入零,在这里我们可以看到相反的情况,但在链接中的例子确定

使用下面的代码,我将作为输出:

ouMx =
     4     7     0
     7     5     9
     0     9     6

所以我觉得我很亲密!

我将输出矩阵归零,而不必在列的底部或顶部插入零。 但我无法完成主对角线上方/下方的处理。

我使用if ( j > i ) swap rows,但它不起作用,所以我只使用交换行。

(我假设这个例子有方阵,但它适用于任何矩阵)

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>

void Diag( int Rows , int Cols , float * inMx , float * ouMx );
void swapRows( int Rows , int Cols , float * Mx );

int main( int argc, const char* argv[] ){

    int Rows = 3 , Cols = 3;  
    float *inMx = (float *) malloc ( Rows * Cols * sizeof (float) );
    float *ouMx = (float *) malloc ( Rows * Cols * sizeof (float) );


    // assume row major order
    inMx[0] = 1.0;
    inMx[1] = 0.0;
    inMx[2] = 0.0;
    inMx[3] = 4.0;
    inMx[4] = 5.0;
    inMx[5] = 6.0;
    inMx[6] = 0.0;
    inMx[7] = 7.0;
    inMx[8] = 9.0;


    // print  input matrix ( row major )
    printf("\n      Input matrix     \n\n");
    for ( int i = 0; i < Rows; i++ )
        for ( int j = 0; j < Cols; j++ ) {
            printf("%f\t",inMx[ i * Cols + j ]);
            if( j == Cols-1 )
                printf("\n");
            }
    printf("\n");

    // extract diagonals
    Diag( Rows , Cols , inMx , ouMx );

    // print Diagonal matrix 
    printf("\n      Diagonal matrix     \n\n");
    for ( int i = 0; i < Rows; i++ )
        for (int j = 0; j < Cols; j++ ) {
            printf("%f\t",ouMx[ i * Cols + j ]);
            if( j == Cols-1 )
                printf("\n");
            }

    printf("\n");


    free( inMx );
    free( ouMx );


  return 0;
}


void Diag( int Rows , int Cols , float * inMx , float * ouMx )
{

    //zero out ouMx
    memset( ouMx , 0 , Rows * Cols * sizeof(float) );

    // scan from the last line to the first -1  for each column
    for ( int j = 0; j < Cols; j++ )
    {
        for ( int i = ( Rows - 1 ); i > 0 ; i-- ) 
        {
            // neglect the zero elements
            if ( inMx[ i * Cols + j ] != 0 )
            {
                ouMx[ i * Cols + j ] = inMx[ i * Cols + j ];

                //if the element in the next colulmn is !=0
                if ( inMx[ ( i + 1 ) * Cols + ( j + 1 ) ] != 0 )
                {
                    ouMx[ ( i + 1 ) * Cols + j ] = inMx[ ( i + 1 ) * Cols + ( j + 1 ) ];

                }

            }
            //if we are above the main diagonal
            //swap elements of a row (in each column) in order to have the zeros at bottom/top
            // if ( i > j ) doesn't work
            swapRows( Rows , Cols , ouMx );
        }


    }


}


void swapRows( int Rows , int Cols , float * Mx )
{

    float temp;

    for ( int j = 0; j < Cols; j++ )
    { 
        for ( int i = ( Rows - 1 ); i > 0 ; i-- ) 
        {

            temp = Mx[ ( i - 1 ) * Cols + j ];
            Mx[ ( i - 1 ) * Cols + j ] = Mx[ i * Cols + j ];
            Mx[ i * Cols + j ] = temp;

        }
    }    

}

0 个答案:

没有答案