假设您在给定的实验中有一些参与者和控制者,并通过三个特征对其进行了评估,如下所示:
part_A <- c(3, 5, 4)
part_B <- c(12, 15, 18)
part_C <- c(50, 40, 45)
ctrl_1 <- c(4, 5, 5)
ctrl_2 <- c(1, 0, 4)
ctrl_3 <- c(13, 16, 17)
ctrl_4 <- c(28, 30, 35)
ctrl_5 <- c(51, 43, 44)
我想为每个参与者查找哪个控制案例是最接近的匹配项。
如果我使用dist()
函数,我可以得到它,但是计算控件之间的距离也要花费很多时间,这对我来说是没有用的(在真实数据中,有1000次控制案例比参与案例更多。
有没有一种方法可以查询这些这些元素中的每个元素与每个这些元素之间的距离?对于大型数据集有什么用?
在上面的示例中,我想要的结果是:
Participant Closest_Ctrl
1 part_A ctrl_1
2 part_B ctrl_3
3 part_C ctrl_5
答案 0 :(得分:3)
这是一种解决方案,对于不太多的参与者应该足够快:
ctrl <- do.call(cbind, mget(ls(pattern = "ctrl_\\d+")))
dat <- mget(ls(pattern = "part_[[:upper:]+]"))
res <- vapply(dat, function(x) colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))],
FUN.VALUE = character(1))
stack(res)
# values ind
#1 ctrl_1 part_A
#2 ctrl_3 part_B
#3 ctrl_5 part_C
如果这太慢了,我会很快用Rcpp对其进行编码。
答案 1 :(得分:1)
将输入转换为数据帧
parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
生成输出
# calculate distances
dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
# generate output by calculating column with min value (max negative value)
data.frame(Participant = names(parts),
Closest_Ctrl = names(ctrl)[max.col(-dists)])
# Participant Closest_Ctrl
# 1 part_A ctrl_1
# 2 part_B ctrl_3
# 3 part_C ctrl_5
基准
parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(100, parts, simplify = F))
ctrl <- do.call(cbind, replicate(100, ctrl, simplify = F))
r1 <- f1()
r2 <- f2()
all.equal(r1 %>% lapply(as.factor) %>% setNames(1:2),
r2[2:1] %>% lapply(as.factor) %>% setNames(1:2))
# [1] TRUE
f1 <- function(x){
dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
# generate output by calculating column with min value (max negative value)
data.frame(Participant = names(parts),
Closest_Ctrl = names(ctrl)[max.col(-dists)])
}
f2 <- function(x){
res <- vapply(parts, function(x) colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))],
FUN.VALUE = character(1))
stack(res)
}
microbenchmark::microbenchmark(f1(), f2(), times = 5)
# Unit: milliseconds
# expr min lq mean median uq max neval
# f1() 305.7324 314.8356 435.3961 324.6116 461.4788 770.3221 5
# f2() 12359.6995 12831.7995 13567.8296 13616.5216 14244.0836 14787.0438 5
基准2
parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(10, parts, simplify = F))
ctrl <- do.call(cbind, replicate(10*1000, ctrl, simplify = F))
r1 <- f1()
r2 <- f2()
all.equal(r1 %>% lapply(as.factor) %>% setNames(1:2),
r2[2:1] %>% lapply(as.factor) %>% setNames(1:2))
# [1] TRUE
f1 <- function(x){
dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
# generate output by calculating column with min value (max negative value)
data.frame(Participant = names(parts),
Closest_Ctrl = names(ctrl)[max.col(-dists)])
}
f2 <- function(x){
res <- vapply(parts, function(x) colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))],
FUN.VALUE = character(1))
stack(res)
}
microbenchmark::microbenchmark(f1(), f2(), times = 5)
# Unit: seconds
# expr min lq mean median uq max neval
# f1() 3.450176 4.211997 4.493805 4.339818 5.154191 5.312844 5
# f2() 119.120484 124.280423 132.637003 130.858727 131.148630 157.776749 5