Rcpp:检索和替换方矩阵的非对角线值

时间:2018-08-13 15:08:23

标签: r rcpp armadillo

使用Rcpp / Armadillo,如何有效地提取/替换方阵的非对角线值?在R中,可以使用:old_values = A[row(A) == (col(A) - k)]; A[row(A) == (col(A) - k)] = new_values。使用Armadillo,可以使用for循环(请参见下文)实现此目标。但是,有没有更简单的方式编写代码?由于我需要对一个大型矩阵(> 10000行,> 10000列)的所有k执行此操作,因此最好考虑效率。这是一个可重现的示例:

A = matrix(1:25, 5, 5) 

A[row(A) == (col(A) - 3)] # extract the 3rd off-diagnal values
A[row(A) == (col(A) - 2)] = -5 of # replace the 2nd off-diagnal values with -5

使用for循环的cpp代码:

arma::vec retrieve_off_diag_values( arma::mat A, unsigned k )
    {
        unsigned n_cols = A.n_cols;
        arma::vec off_diag_values(n_cols - k);
        for( unsigned i=0; i <(n_cols - k); i++ )
        {
            off_diag_values(i) = A(i, i+k);
        } 
        return off_diag_values;
    } 

2 个答案:

答案 0 :(得分:1)

检索非对角线

您可以使用犰狳的.diag()成员函数和索引k来检索非对角线。

情况:

  • 如果k == 0(默认),则是主对角线。
  • 否则为k < 0,则为下三角形对角线。
  • 然后是上对角线三角形。

示例:

#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::vec offdiag_extract(arma::mat& A, int k) {
  return A.diag(k);
}

测试:

A = matrix(1:25, 5, 5) 
offdiag_extract(A, 3)
#      [,1]
# [1,]   16
# [2,]   22

替换对角线

编辑:由于@mtall关于其他成员函数提供的行为的观点,本节已更新。

但是,.diag()=仅可用于保存到主对角线中。为了确保替换对所有对角线都可行,您需要将.diag()成员函数与.fill(value)链接起来,例如

#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::mat offdiag_fill_arma(arma::mat& A, int k, double replace_value) {
    A.diag(k).fill(replace_value);
    return A;
}

测试:

offdiag_fill_arma(A, 2, 4)
#      [,1] [,2] [,3] [,4] [,5]
# [1,]    1    6    4   16   21
# [2,]    2    7   12    4   22
# [3,]    3    8   13   18    4
# [4,]    4    9   14   19   24
# [5,]    5   10   15   20   25

实施非对角线替换

简而言之,可以使用具有适当for偏移的单个k循环来实现非对角线替换。

#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::mat offdiag_replace(arma::mat& A, int k, double replace_val = -5) {

  // Determine whether to go over upper or lower diagonal  
  unsigned int row_offset = (k < 0) ? -k : 0;
  unsigned int col_offset = (k > 0) ?  k : 0;

  // Compute total number of elements
  unsigned int N = std::min(A.n_rows - row_offset, A.n_cols - col_offset);

  // Loop over diagonal
  for(unsigned int i = 0; i < N; ++i) {

    unsigned int row = i + row_offset;
    unsigned int col = i + col_offset;

    // Disregard bounds checks with .at()
    A.at(row,col) = replace_val;
  }

  return A;
}

测试:

offdiag_replace(A, 2, 4)
#      [,1] [,2] [,3] [,4] [,5]
# [1,]    1    6    4   16   21
# [2,]    2    7   12    4   22
# [3,]    3    8   13   18    4
# [4,]    4    9   14   19   24
# [5,]    5   10   15   20   25

答案 1 :(得分:1)

要将指定对角线上的值提取到向量中,其中k <0表示子对角线,k = 0表示主对角线,而k> 0表示超对角线:

#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::vec diag_get(const arma::mat& X, int k)   // note the 'const' and '&'
{
    return X.diag(k);
}

要将对角线上的值设置为特定值:

// [[Rcpp::export]]
void diag_fill(arma::mat& X, int k, double value)   // note the '&' character
{
    X.diag(k).fill(value);
}

要在对角线上将特定值的实例更改为另一个值:

// [[Rcpp::export]]
void diag_change(arma::mat& X, int k, double old_value, double new_value)
{
    X.diag(k).replace(old_value, new_value);
}