我通常可以通过一点思考弄清楚如何进行矢量化,但是尽管阅读了一堆StackOverflow q& a,我仍然难倒! 我想用适当的应用函数替换这些嵌套的for循环,但是如果我错过了整个问题的明显不同方法,请随时告诉我!
在测试的上下文中考虑这个例子,其中第一行是关键,后面的每一行都是学生'答案。作为输出,我想要一个数组,每个正确答案为1,每个错误答案为0。 for循环可以工作,但是当你扩展到数千行和列时非常慢。
这是我可以重现的例子,并提前感谢您的帮助!
#build sample data
dat <- array(dim=c(9,6))
for (n in 1:9){
dat[n,1:6] <- c(paste("ID00",n,sep=""),
sample(c("A","B","C","D"), size=5, replace=TRUE))}
dat[3,4]<-NA
key<-c("key","A","B","B","C","D")
dat <- rbind(key,dat)
>dat
[,1] [,2] [,3] [,4] [,5] [,6]
"key" "A" "B" "B" "C" "D"
"ID001" "B" "A" "D" "B" "C"
"ID002" "C" "C" "C" "B" "B"
"ID003" "A" "C" NA "D" "D"
"ID004" "D" "B" "D" "A" "A"
"ID005" "A" "C" "A" "C" "A"
"ID006" "D" "D" "B" "B" "A"
"ID007" "B" "D" "A" "D" "A"
"ID008" "D" "D" "B" "D" "A"
"ID009" "D" "C" "B" "D" "D"
#score file
dat2 <- array(dim=c(9,5))
for (row in 2:10){
for (column in 2:6){
if (is.na(dat[row,column])){
p <- NA
}else if (dat[row,column]==dat[1,column]){
p <- 1
}else p <- 0
dat2[row-1,column-1]<-p
}
}
> dat2
[,1] [,2] [,3] [,4] [,5]
[1,] 0 0 0 0 0
[2,] 0 0 0 0 0
[3,] 1 0 NA 0 1
[4,] 0 1 0 0 0
[5,] 1 0 0 1 0
[6,] 0 0 1 0 0
[7,] 0 0 0 0 0
[8,] 0 0 1 0 0
[9,] 0 0 1 0 1
答案 0 :(得分:1)
为可重复性设置种子:
set.seed(1)
dat <- array(dim=c(9,6))
for (n in 1:9){
dat[n,1:6] <- c(paste("ID00",n,sep=""),
sample(c("A","B","C","D"), size=5, replace=TRUE))}
dat[3,4]<-NA
key<-c("key","A","B","B","C","D")
dat <- rbind(key,dat)
这将完成这项工作:
key <- rep(dat[1, -1], each = nrow(dat) - 1L) ## expand "key" row
dummy <- (dat[-1, -1] == key) + 0L ## vectorized / element-wise "=="
基本上我们想要一个矢量化的"=="
。但我们需要先将dat[1,-1]
展开到dat[-1,-1]
的同一维度。最后+ 0L
强制TRUE / FALSE
矩阵到1 / 0
矩阵。
# [,1] [,2] [,3] [,4] [,5]
# 0 1 0 0 0
# 0 0 0 1 0
# 1 0 NA 0 1
# 0 0 0 0 1
# 0 0 0 0 0
# 0 0 1 0 0
# 0 0 1 0 1
# 0 0 0 1 0
# 0 0 0 1 0
我还没有查看Gregor的基准脚本。但这是我的。
set.seed(1)
dat <- matrix(sample(LETTERS[4], 1000 * 1000, TRUE), 1000)
key <- sample(LETTERS[1:4], 1000, TRUE)
microbenchmark(rep(key, each = 1000) == dat, t(t(dat) == key))
#Unit: milliseconds
# expr min lq mean median uq
# rep(key, each = 1000) == dat 32.16888 34.01138 42.61639 35.57526 40.27944
# t(t(dat) == key) 50.93348 52.96008 63.74475 56.04706 60.38750
# max neval cld
# 81.96044 100 a
# 106.54916 100 b
我的方法和Gregor的唯一区别是rep(, each)
扩展vs.s. rep_len
扩张。两次扩展都需要相同的内存量,扩展后,"=="
以列方式完成。我预测额外的开销将由两个t()
引起,基准测试结果似乎是合理的。希望结果不依赖于平台。
答案 1 :(得分:1)
这与哲源的回答基本相同(依靠向量化set.seed(1)
然后强制回数字),我只是先调换矩阵而不是扩展密钥。
由于矩阵是按列而不是行存储/操作的,如果键是一列,每个学生也是一个列向量,那么回收就可以了。
在生成数据之前使用key = dat[1, -1]
tdat = t(dat[-1, -1])
t((tdat == key) + 0L)
# [,1] [,2] [,3] [,4] [,5]
# 0 1 0 0 0
# 0 0 0 1 0
# 1 0 NA 0 1
# 0 0 0 0 1
# 0 0 0 0 0
# 0 0 1 0 0
# 0 0 1 0 1
# 0 0 0 1 0
# 0 0 0 1 0
...
'key'
如果您改为将第一列更改为行名,则可以轻松保留它们,而不会将学生ID标记为不正确,因为它们不是row.names(dat) = dat[, 1]
dat = dat[, -1]
key = dat[1, ]
tdat = t(dat[-1, ])
result = t((tdat == key) + 0)
result
# [,1] [,2] [,3] [,4] [,5]
# ID001 0 1 0 0 0
# ID002 0 0 0 1 0
# ID003 1 0 NA 0 1
# ID004 0 0 0 0 1
# ID005 0 0 0 0 0
# ID006 0 0 1 0 0
# ID007 0 0 1 0 1
# ID008 0 0 0 1 0
# ID009 0 0 0 1 0
rowSums(result)
# ID001 ID002 ID003 ID004 ID005 ID006 ID007 ID008 ID009
# 1 1 NA 1 0 1 2 1 1
。这样就可以更好地总结事情:
gregor = function(key, dat) {
t(t(dat) == key)
}
zheyuan = function(key, dat) {
dat == rep(key, each = nrow(dat))
}
library(microbenchmark)
nr = 10000
nc = 1000
key = sample(1:10, nc, replace = T)
dat = matrix(sample(1:10, nr * nc, replace = T), nrow = nr)
print(microbenchmark(gregor(key, dat), zheyuan(key, dat)), signif = 4)
# Unit: milliseconds
# expr min lq mean median uq max neval cld
# gregor(key, dat) 104.5 113.2 135.5970 128.2 144.5 336.2 100 a
# zheyuan(key, dat) 196.0 202.8 215.7822 207.0 224.9 394.4 100 b
identical(gregor(key, dat), zheyan(key, dat))
# [1] TRUE
简化输入并对中等大小的数据运行基准测试,两者都非常快。双转置速度要快一些。
D:\Android_Dev\Android_sdk\platform-tools>adb shell dumpsys cpuinfo
Load: 4.03 / 3.43 / 2.44
CPU usage from 23770ms to 16630ms ago:
58% 1844/logd: 58% user + 0% kernel / faults: 3 minor
50% 3895/com.google.android.wearable.app:ui: 41% user + 9.3% kernel / faults: 1798 minor
26% 1864/adbd: 2.8% user + 23% kernel / faults: 1243 minor
22% 4880/logcat: 7.8% user + 15% kernel
9.7% 7834/kworker/0:2: 0% user + 9.7% kernel
4.9% 2198/system_server: 2.6% user + 2.2% kernel / faults: 76 minor
答案 2 :(得分:0)
如果您希望将其放在没有for
或apply
的一行中,请尝试
dat2 <- matrix(as.numeric(dat==rep(dat[1,],each=nrow(dat))),nrow=nrow(dat))[-1,-1]