R dist:计算几对多的距离

时间:2019-03-11 15:11:05

标签: r distance

假设您在给定的实验中有一些参与者和控制者,并通过三个特征对其进行了评估,如下所示:

part_A <- c(3, 5, 4)
part_B <- c(12, 15, 18)
part_C <- c(50, 40, 45)

ctrl_1 <- c(4, 5, 5)
ctrl_2 <- c(1, 0, 4)
ctrl_3 <- c(13, 16, 17)
ctrl_4 <- c(28, 30, 35)
ctrl_5 <- c(51, 43, 44)

我想为每个参与者查找哪个控制案例是最接近的匹配项。

如果我使用dist()函数,我可以得到它,但是计算控件之间的距离也要花费很多时间,这对我来说是没有用的(在真实数据中,有1000次控制案例比参与案例更多。

有没有一种方法可以查询这些这些元素中的每个元素与每个这些元素之间的距离?对于大型数据集有什么用?

在上面的示例中,我想要的结果是:

  Participant Closest_Ctrl
1      part_A       ctrl_1
2      part_B       ctrl_3
3      part_C       ctrl_5

2 个答案:

答案 0 :(得分:3)

这是一种解决方案,对于不太多的参与者应该足够快:

ctrl <- do.call(cbind, mget(ls(pattern = "ctrl_\\d+")))

dat <- mget(ls(pattern = "part_[[:upper:]+]"))

res <- vapply(dat, function(x)  colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))], 
                FUN.VALUE = character(1))

stack(res)
#  values    ind
#1 ctrl_1 part_A
#2 ctrl_3 part_B
#3 ctrl_5 part_C

如果这太慢了,我会很快用Rcpp对其进行编码。

答案 1 :(得分:1)

将输入转换为数据帧

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))

生成输出

# calculate distances
dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))

# generate output by calculating column with min value (max negative value)
data.frame(Participant = names(parts), 
           Closest_Ctrl = names(ctrl)[max.col(-dists)])

#   Participant Closest_Ctrl
# 1      part_A       ctrl_1
# 2      part_B       ctrl_3
# 3      part_C       ctrl_5

基准

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(100, parts, simplify = F))
ctrl <- do.call(cbind, replicate(100, ctrl, simplify = F))

r1 <- f1()
r2 <- f2()

all.equal(r1 %>% lapply(as.factor) %>% setNames(1:2), 
          r2[2:1] %>% lapply(as.factor) %>% setNames(1:2))
# [1] TRUE


f1 <- function(x){
  dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
  # generate output by calculating column with min value (max negative value)
  data.frame(Participant = names(parts), 
             Closest_Ctrl = names(ctrl)[max.col(-dists)])
}

f2 <- function(x){
  res <- vapply(parts, function(x)  colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))], 
            FUN.VALUE = character(1))

  stack(res)
}

microbenchmark::microbenchmark(f1(), f2(), times = 5)        
# Unit: milliseconds
#  expr        min         lq       mean     median         uq        max neval
#  f1()   305.7324   314.8356   435.3961   324.6116   461.4788   770.3221     5
#  f2() 12359.6995 12831.7995 13567.8296 13616.5216 14244.0836 14787.0438     5

基准2

parts <- do.call(data.frame, mget(ls(pattern = "part_[A-C]")))
ctrl <- do.call(data.frame, mget(ls(pattern = "ctrl_[1-5]")))
parts <- do.call(cbind, replicate(10, parts, simplify = F))
ctrl <- do.call(cbind, replicate(10*1000, ctrl, simplify = F))

r1 <- f1()
r2 <- f2()

all.equal(r1 %>% lapply(as.factor) %>% setNames(1:2), 
          r2[2:1] %>% lapply(as.factor) %>% setNames(1:2))
# [1] TRUE


f1 <- function(x){
  dists <- outer(parts, ctrl, Vectorize(function(x, y) sqrt(sum((x - y)^2))))
  # generate output by calculating column with min value (max negative value)
  data.frame(Participant = names(parts), 
             Closest_Ctrl = names(ctrl)[max.col(-dists)])
}

f2 <- function(x){
  res <- vapply(parts, function(x)  colnames(ctrl)[which.min(sqrt(colSums(x - ctrl)^2))], 
            FUN.VALUE = character(1))

  stack(res)
}

microbenchmark::microbenchmark(f1(), f2(), times = 5)        
# Unit: seconds
#  expr        min         lq       mean     median         uq        max neval
#  f1()   3.450176   4.211997   4.493805   4.339818   5.154191   5.312844     5
#  f2() 119.120484 124.280423 132.637003 130.858727 131.148630 157.776749     5