假设我有矩阵M
,我只想保留这些矩阵中每行的2个最高值,其他的将设置为零。
M <- rbind(c(0.1, 0.6, 0.2, 0.3, 0.7), c(0.8, 0.1, 0.7, 0.2, 0.4))
> M
[,1] [,2] [,3] [,4] [,5]
[1,] 0.1 0.6 0.2 0.3 0.7
[2,] 0.8 0.1 0.7 0.2 0.4
我想要这个结果。
rbind(c(0, 0.6, 0, 0, 0.7), c(0.8, 0, 0.7, 0, 0))
> rbind(c(0, 0.6, 0, 0, 0.7), c(0.8, 0, 0.7, 0, 0))
[,1] [,2] [,3] [,4] [,5]
[1,] 0.0 0.6 0.0 0 0.7
[2,] 0.8 0.0 0.7 0 0.0
我理解apply(M, 1, sort)
可以做到这一点,但如果矩阵M
很大,那么它会很慢,那么最快的方法是什么?
感谢。
答案 0 :(得分:3)
我建议使用data.table
,这涉及一些重塑,但应该很快。如果你可以在最后完成最后的重塑步骤而离开,那么它也应该节省一些时间。
library(data.table)
dt <- as.data.table(M)
## define a 'grouping variable', which in this case is just the row number
## this lets us keep track of the row of the matrix
dt[, grp := .I]
## melt into long form
dt <- melt(dt, id.vars = "grp")
## order the data by the value, for each group, and select the top 2 rows
dt_max <- dt[ dt[ order(-value), .I[c(1,2)], by = .(grp)]$V1 ]
## set all the original values to 0
dt[, value := 0]
## then overwrite those 0s with the 'top 2' values in dt_max
dt[ dt_max, on = c("grp", "variable"), value := i.value]
as.matrix(dcast(dt, formula = grp ~ variable))
grp V1 V2 V3 V4 V5
[1,] 1 0.0 0.6 0.0 0 0.7
[2,] 2 0.8 0.0 0.7 0 0.0
答案 1 :(得分:3)
使用pmax
的方法:
m <- M
x1 <- do.call(pmax, lapply(1:ncol(M), function(x) M[, x]))
m[m == x1] <- NA
x2 <- do.call(pmax, c(lapply(1:ncol(M), function(x) m[, x]), na.rm = T))
M[M != x1 & M != x2] <- 0
M
一些时间安排。设置一个大矩阵,然后运行其他几个提议的方法:
set.seed(1234)
M <- matrix(floor(rnorm(1e7, 100, 10)), nc = 10)
f1 <- function(M) {
m <- M
x1 <- do.call(pmax, lapply(1:ncol(M), function(x) M[, x]))
m[m == x1] <- NA
x2 <- do.call(pmax, c(lapply(1:ncol(M), function(x) m[, x]), na.rm = T))
M[M != x1 & M != x2] <- 0
M
}
f2 <- function(M) {
dt <- as.data.table(M)
dt[, grp := 1:.N]
dt <- melt(dt, id.vars = "grp")
dt_max <- dt[ dt[ order(-value), .I[c(1,2)], by = .(grp)]$V1 ]
dt[, value := 0]
dt[ dt_max, on = c("grp", "variable"), value := i.value]
as.matrix(dcast(dt, formula = grp ~ variable))
}
f3 <- function(M) {
tmp <- data.frame(row=c(row(M)), val=c(M), seq=seq_along(M))
tmp <- tmp[do.call(order,c(tmp[1:2],decreasing=TRUE)),]
M[tmp$seq] <- with(tmp, ave(val,row,FUN=function(x) replace(x, -(1:2), 0) ))
M
}
使用microbenchmark
进行基准测试,如@SymbolixAU所建议的那样:
microbenchmark::microbenchmark(
f1 = { f1(M) },
f2 = { f2(M) },
f3 = { f3(M) },
times = 10L)
# Unit: milliseconds
# expr min lq mean median uq max neval cld
# f1 926.9069 946.6892 1084.038 1009.497 1082.454 1476.972 10 a
# f2 6315.3971 6750.1864 7327.610 7237.323 7785.078 9198.780 10 b
# f3 13076.0617 13435.9920 15360.451 15118.323 16497.295 19792.398 10 c
此外,如果给定行的最大两个数字有重复,则其他方法似乎将重复项设置为零。
答案 2 :(得分:2)
某些逻辑为@SymbolixAU,但使用基本R函数:
tmp <- data.frame(row=c(row(M)), val=c(M), seq=seq_along(M))
tmp <- tmp[do.call(order,c(tmp[1:2],decreasing=TRUE)),]
M[tmp$seq] <- with(tmp, ave(val,row,FUN=function(x) replace(x, -(1:2), 0) ))
M
# [,1] [,2] [,3] [,4] [,5]
#[1,] 0.0 0.6 0.0 0 0.7
#[2,] 0.8 0.0 0.7 0 0.0
答案 3 :(得分:0)
mx1 = max(M[1,])
wh1 = which(M[1,]==mx,arr.ind=TRUE)
mx2 = max(M[1,-wh1])
wh2 = which(M[1,-wh1]==mx2,arr.ind=TRUE)
然后在一些新分配的零数组中将这些给定值分配给这些给定的索引。