如何编写一个提取矩阵对角线的函数?

时间:2016-10-12 22:39:08

标签: r

所以我试图在R中编写一个提取矩阵对角线的函数,即像diag(x)一样工作,显然不使用diag(x)。

我不确定从哪里开始。

2 个答案:

答案 0 :(得分:2)

适用于非方形矩阵:

diag2 <- function(x){
    n <- min(dim(x))
    return(x[matrix(rep(1:n, 2), n, 2)])
}

此外,如果你看diag,你可以看到发生了什么:

if (is.matrix(x)) {
    if (nargs() > 1L) 
        stop("'nrow' or 'ncol' cannot be specified when 'x' is a matrix")
    if ((m <- min(dim(x))) == 0L) 
        return(vector(typeof(x), 0L))
    y <- x[1 + 0L:(m - 1L) * (dim(x)[1L] + 1)]
    nms <- dimnames(x)
    if (is.list(nms) && !any(sapply(nms, is.null)) && identical((nm <- nms[[1L]][seq_len(m)]), 
        nms[[2L]][seq_len(m)])) 
        names(y) <- nm
    return(y)
}
## there's more...

如果有人好奇,我尝试了基准测试,包括@ shayaa的方法...

set.seed(101)
a <- matrix(runif(1e6), 1e3, 1e3)

diag3 <- function(x){
    x[row(x) == col(x)]
}
library(microbenchmark)
microbenchmark(diag(a), diag2(a), diag3(a))
## Unit: microseconds
##      expr       min        lq        mean     median         uq        max neval
##   diag(a)    23.205    33.915    47.59246    47.3030    58.2355     79.878   100
##  diag2(a)    31.238    37.262    58.03028    57.5665    70.7300    107.546   100
##  diag3(a) 11744.788 12659.595 15425.79847 13874.7265 15054.8285 164130.271   100

答案 1 :(得分:2)

set.seed(1111)
df <- matrix(rnorm(16),4,4)
df[row(df)==col(df)]
如果你的矩阵不是正方形

也有效

set.seed(1111); df <- matrix(rnorm(30),5,6)
df[row(df)==col(df)]