使用`dplyr`避免`for`循环:计算到观测值的距离

时间:2018-12-11 18:36:49

标签: r dplyr

我有两个数据集AB,对于A中的每个观测值,我想计算一个距离distance(例如,欧氏距离,L1距离或B中的每个观测值(距离的计算均基于数据集中的变量)。来自A的观测值应该与B中的观测值相关,该距离最小。

例如,如果A有5000个观测值,而B有10000个观测值,则

for(i in 1:5000)
{
     x = data.frame(x = numeric(), y = numeric())

     for(j in 1:10000)
     {
         x[j,] = distance(A[i,], B[j,])
     }

     A[i,]$associated_row_B = x[which.min(x[1,]),1]
}

基本上是我想要的(如果观察距离相同,我仍然必须解决)。但是由于我使用的是dplyr,所以我几乎不需要使用for循环。我的解决方案甚至需要两个循环,所以我想知道是否有可能使用dplyr / tidyverse中的解决方案来避免for循环。

一个非常基本的例子:

A:

i           a b
1 -0.5920377 a
2  0.4263199 b
3  0.6737029 a
4  1.3063658 c
5  0.1314103 d

B:

i           a b
1 -0.30201541 a
2 -0.07093386 b
3  0.96317764 c
4 -0.33303061 d
5 -1.00834895 d

和距离函数:

distance = function(x,y) return(c((x[2] - y[2])^2 + abs(x[3] - y[3]), y[1])

返回值的第一个元素是实际距离,第二个值是来自B的标识符。

1 个答案:

答案 0 :(得分:0)

警告:对于大型数据集,这将是非常低效的!

您可以使用crossing中的tidyrslice中的dplyr来完成此操作。

首先,让我们创建两个虚拟数据帧A_dfB_df

A_df <- data.frame(
  observation_A = runif(100),
  id_A = 1:100
)

B_df <- data.frame(
  observation_B = runif(50),
  id_B = 1:50
)

为清楚起见,我在A_dfB_df之间保留了唯一的列名。接下来,我们将使用tidyr::crossing查找两个数据帧之间的行的每种组合。接下来,我们使用mutate来计算距离(在这里,我任意取了它们的差的绝对值,但是您可以在此处应用自定义距离函数)。最后,我们将id_A分组,并使用slice(和基数R which.max)仅保留最小值。

library(tidyverse)


full_df <- A_df %>% 
  crossing(B_df) %>% 
  mutate(distance = abs(observation_A-observation_B)) %>% 
  group_by(id_A) %>% 
  slice(which.min(distance))

看看full_df,我们得到了希望的结果:

> full_df
# A tibble: 100 x 5
# Groups:   id_A [100]
   observation_A  id_A observation_B  id_B distance
           <dbl> <int>         <dbl> <int>    <dbl>
 1         0.826     1         0.851    44  0.0251 
 2         0.903     2         0.905     3  0.00176
 3         0.371     3         0.368    18  0.00305
 4         0.554     4         0.577    34  0.0232 
 5         0.656     5         0.654    10  0.00268
 6         0.120     6         0.110    37  0.0101 
 7         0.991     7         0.988     6  0.00244
 8         0.983     8         0.988     6  0.00483
 9         0.325     9         0.318    45  0.00649
10         0.860    10         0.864    40  0.00407
# ... with 90 more rows