使用Rcpp / Armadillo,如何有效地提取/替换方阵的非对角线值?在R中,可以使用:old_values = A[row(A) == (col(A) - k)]
; A[row(A) == (col(A) - k)] = new_values
。使用Armadillo,可以使用for循环(请参见下文)实现此目标。但是,有没有更简单的方式编写代码?由于我需要对一个大型矩阵(> 10000行,> 10000列)的所有k
执行此操作,因此最好考虑效率。这是一个可重现的示例:
A = matrix(1:25, 5, 5)
A[row(A) == (col(A) - 3)] # extract the 3rd off-diagnal values
A[row(A) == (col(A) - 2)] = -5 of # replace the 2nd off-diagnal values with -5
使用for循环的cpp代码:
arma::vec retrieve_off_diag_values( arma::mat A, unsigned k )
{
unsigned n_cols = A.n_cols;
arma::vec off_diag_values(n_cols - k);
for( unsigned i=0; i <(n_cols - k); i++ )
{
off_diag_values(i) = A(i, i+k);
}
return off_diag_values;
}
答案 0 :(得分:1)
您可以使用犰狳的.diag()
成员函数和索引k
来检索非对角线。
情况:
k == 0
(默认),则是主对角线。k < 0
,则为下三角形对角线。示例:
#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::vec offdiag_extract(arma::mat& A, int k) {
return A.diag(k);
}
测试:
A = matrix(1:25, 5, 5)
offdiag_extract(A, 3)
# [,1]
# [1,] 16
# [2,] 22
编辑:由于@mtall关于其他成员函数提供的行为的观点,本节已更新。
但是,.diag()=
仅可用于保存到主对角线中。为了确保替换对所有对角线都可行,您需要将.diag()
成员函数与.fill(value)
链接起来,例如
#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::mat offdiag_fill_arma(arma::mat& A, int k, double replace_value) {
A.diag(k).fill(replace_value);
return A;
}
测试:
offdiag_fill_arma(A, 2, 4)
# [,1] [,2] [,3] [,4] [,5]
# [1,] 1 6 4 16 21
# [2,] 2 7 12 4 22
# [3,] 3 8 13 18 4
# [4,] 4 9 14 19 24
# [5,] 5 10 15 20 25
简而言之,可以使用具有适当for
偏移的单个k
循环来实现非对角线替换。
#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::mat offdiag_replace(arma::mat& A, int k, double replace_val = -5) {
// Determine whether to go over upper or lower diagonal
unsigned int row_offset = (k < 0) ? -k : 0;
unsigned int col_offset = (k > 0) ? k : 0;
// Compute total number of elements
unsigned int N = std::min(A.n_rows - row_offset, A.n_cols - col_offset);
// Loop over diagonal
for(unsigned int i = 0; i < N; ++i) {
unsigned int row = i + row_offset;
unsigned int col = i + col_offset;
// Disregard bounds checks with .at()
A.at(row,col) = replace_val;
}
return A;
}
测试:
offdiag_replace(A, 2, 4)
# [,1] [,2] [,3] [,4] [,5]
# [1,] 1 6 4 16 21
# [2,] 2 7 12 4 22
# [3,] 3 8 13 18 4
# [4,] 4 9 14 19 24
# [5,] 5 10 15 20 25
答案 1 :(得分:1)
要将指定对角线上的值提取到向量中,其中k <0表示子对角线,k = 0表示主对角线,而k> 0表示超对角线:
#include<RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::vec diag_get(const arma::mat& X, int k) // note the 'const' and '&'
{
return X.diag(k);
}
要将对角线上的值设置为特定值:
// [[Rcpp::export]]
void diag_fill(arma::mat& X, int k, double value) // note the '&' character
{
X.diag(k).fill(value);
}
要在对角线上将特定值的实例更改为另一个值:
// [[Rcpp::export]]
void diag_change(arma::mat& X, int k, double old_value, double new_value)
{
X.diag(k).replace(old_value, new_value);
}