在R

时间:2017-02-09 01:59:42

标签: r

假设我有矩阵M,我只想保留这些矩阵中每行的2个最高值,其他的将设置为零。

M <- rbind(c(0.1, 0.6, 0.2, 0.3, 0.7), c(0.8, 0.1, 0.7, 0.2, 0.4))

> M
     [,1] [,2] [,3] [,4] [,5]
[1,]  0.1  0.6  0.2  0.3  0.7
[2,]  0.8  0.1  0.7  0.2  0.4

我想要这个结果。

rbind(c(0, 0.6, 0, 0, 0.7), c(0.8, 0, 0.7, 0, 0))


>  rbind(c(0, 0.6, 0, 0, 0.7), c(0.8, 0, 0.7, 0, 0))
     [,1] [,2] [,3] [,4] [,5]
[1,]  0.0  0.6  0.0    0  0.7
[2,]  0.8  0.0  0.7    0  0.0

我理解apply(M, 1, sort)可以做到这一点,但如果矩阵M很大,那么它会很慢,那么最快的方法是什么?

感谢。

4 个答案:

答案 0 :(得分:3)

我建议使用data.table,这涉及一些重塑,但应该很快。如果你可以在最后完成最后的重塑步骤而离开,那么它也应该节省一些时间。

library(data.table)

dt <- as.data.table(M)

## define a 'grouping variable', which in this case is just the row number
## this lets us keep track of the row of the matrix
dt[, grp := .I]

## melt into long form
dt <- melt(dt, id.vars = "grp")

## order the data by the value, for each group, and select the top 2 rows
dt_max <- dt[ dt[ order(-value), .I[c(1,2)], by = .(grp)]$V1 ]

## set all the original values to 0
dt[, value := 0]

## then overwrite those 0s with the 'top 2' values in dt_max
dt[ dt_max, on = c("grp", "variable"), value := i.value]

as.matrix(dcast(dt, formula = grp ~ variable))
     grp  V1  V2  V3 V4  V5
[1,]   1 0.0 0.6 0.0  0 0.7
[2,]   2 0.8 0.0 0.7  0 0.0

答案 1 :(得分:3)

使用pmax的方法:

m <- M
x1 <- do.call(pmax, lapply(1:ncol(M), function(x) M[, x]))
m[m == x1]  <- NA
x2 <- do.call(pmax, c(lapply(1:ncol(M), function(x) m[, x]), na.rm = T))
M[M != x1 & M != x2] <- 0
M  

一些时间安排。设置一个大矩阵,然后运行其他几个提议的方法:

set.seed(1234)
M <- matrix(floor(rnorm(1e7, 100, 10)), nc = 10)
f1 <- function(M) {
  m <- M
  x1 <- do.call(pmax, lapply(1:ncol(M), function(x) M[, x]))
  m[m == x1]  <- NA
  x2 <- do.call(pmax, c(lapply(1:ncol(M), function(x) m[, x]), na.rm = T))
  M[M != x1 & M != x2] <- 0
  M  
}

f2 <- function(M) {
  dt <- as.data.table(M)
  dt[, grp := 1:.N]
  dt <- melt(dt, id.vars = "grp")
  dt_max <- dt[ dt[ order(-value), .I[c(1,2)], by = .(grp)]$V1 ]
  dt[, value := 0]
  dt[ dt_max, on = c("grp", "variable"), value := i.value]
  as.matrix(dcast(dt, formula = grp ~ variable))  
}

f3 <- function(M) {
  tmp <- data.frame(row=c(row(M)), val=c(M), seq=seq_along(M))
  tmp <- tmp[do.call(order,c(tmp[1:2],decreasing=TRUE)),]
  M[tmp$seq] <- with(tmp, ave(val,row,FUN=function(x) replace(x, -(1:2), 0) ))
  M
}

使用microbenchmark进行基准测试,如@SymbolixAU所建议的那样:

microbenchmark::microbenchmark(
  f1 = { f1(M) }, 
  f2 = { f2(M) }, 
  f3 = { f3(M) },
  times = 10L)
# Unit: milliseconds
#  expr        min         lq      mean    median        uq       max neval cld
#    f1   926.9069   946.6892  1084.038  1009.497  1082.454  1476.972    10 a  
#    f2  6315.3971  6750.1864  7327.610  7237.323  7785.078  9198.780    10  b 
#    f3 13076.0617 13435.9920 15360.451 15118.323 16497.295 19792.398    10   c

此外,如果给定行的最大两个数字有重复,则其他方法似乎将重复项设置为零。

答案 2 :(得分:2)

某些逻辑为@SymbolixAU,但使用基本R函数:

tmp <- data.frame(row=c(row(M)), val=c(M), seq=seq_along(M))
tmp <- tmp[do.call(order,c(tmp[1:2],decreasing=TRUE)),]
M[tmp$seq] <- with(tmp, ave(val,row,FUN=function(x) replace(x, -(1:2), 0) ))
M

#     [,1] [,2] [,3] [,4] [,5]
#[1,]  0.0  0.6  0.0    0  0.7
#[2,]  0.8  0.0  0.7    0  0.0

答案 3 :(得分:0)

mx1 = max(M[1,])
wh1 = which(M[1,]==mx,arr.ind=TRUE)
mx2 = max(M[1,-wh1])
wh2 = which(M[1,-wh1]==mx2,arr.ind=TRUE)

然后在一些新分配的零数组中将这些给定值分配给这些给定的索引。