如何基于多个变量在R中创建平衡集

时间:2017-10-05 16:47:00

标签: r grouping subset balanced-groups

我有一个大型数据集,我需要分成多个平衡集。

该集合如下所示:

> data<-matrix(runif(4000, min=0, max=10), nrow=500, ncol=8 )
> colnames(data)<-c("A","B","C","D","E","F","G","H")

每个包含例如20行的集合将需要在多个变量之间进行平衡,以使每个子集最终具有与所有其他子集相比包含在其子组中的B,C,D的相似均值。

有没有办法用R做到这一点?任何建议将不胜感激。先感谢您!

2 个答案:

答案 0 :(得分:0)

library(tidyverse)

# Reproducible data
set.seed(2)
data<-matrix(runif(4000, min=0, max=10), nrow=500, ncol=8 )
colnames(data)<-c("A","B","C","D","E","F","G","H")

data=as.data.frame(data)

更新了答案

如果您希望将给定行的观察值保持在一起,则可能无法在每列中的集合之间获得类似的方法。使用8列(如示例数据中所示),您需要25个20行集,其中每列A集具有相同的均值,每列B集具有相同的均值,等等。限制。但是,可能有一些算法可以找到最小化集合均值的集合的成员资格分配计划。

但是,如果您可以分别从每列中获取20个观察结果而不考虑它来自哪一行,那么这里有一个选项:

# Group into sets with same means
same_means = data %>% 
  gather(key, value) %>% 
  arrange(value) %>% 
  group_by(key) %>% 
  mutate(set = c(rep(1:25, 10), rep(25:1, 10)))

# Check means by set for each column
same_means %>% 
  group_by(key, set) %>% 
  summarise(mean=mean(value)) %>% 
  spread(key, mean) %>% as.data.frame
   set        A        B        C        D        E        F        G        H
1    1 4.940018 5.018584 5.117592 4.931069 5.016401 5.171896 4.886093 5.047926
2    2 4.946496 5.018578 5.124084 4.936461 5.017041 5.172817 4.887383 5.048850
3    3 4.947443 5.021511 5.125649 4.929010 5.015181 5.173983 4.880492 5.044192
4    4 4.948340 5.014958 5.126480 4.922940 5.007478 5.175898 4.878876 5.042789
5    5 4.943010 5.018506 5.123188 4.924283 5.019847 5.174981 4.869466 5.046532
6    6 4.942808 5.019945 5.123633 4.924036 5.019279 5.186053 4.870271 5.044757
7    7 4.945312 5.022991 5.120904 4.919835 5.019173 5.187910 4.869666 5.041317
8    8 4.947457 5.024992 5.125821 4.915033 5.016782 5.187996 4.867533 5.043262
9    9 4.936680 5.020040 5.128815 4.917770 5.022527 5.180950 4.864416 5.043587
10  10 4.943435 5.022840 5.122607 4.921102 5.018274 5.183719 4.872688 5.036263
11  11 4.942015 5.024077 5.121594 4.921965 5.015766 5.185075 4.880304 5.045362
12  12 4.944416 5.024906 5.119663 4.925396 5.023136 5.183449 4.887840 5.044733
13  13 4.946751 5.020960 5.127302 4.923513 5.014100 5.186527 4.889140 5.048425
14  14 4.949517 5.011549 5.127794 4.925720 5.006624 5.188227 4.882128 5.055608
15  15 4.943008 5.013135 5.130486 4.930377 5.002825 5.194421 4.884593 5.051968
16  16 4.939554 5.021875 5.129392 4.930384 5.005527 5.197746 4.883358 5.052474
17  17 4.935909 5.019139 5.131258 4.922536 5.003273 5.204442 4.884018 5.059162
18  18 4.935830 5.022633 5.129389 4.927106 5.008391 5.210277 4.877859 5.054829
19  19 4.936171 5.025452 5.127276 4.927904 5.007995 5.206972 4.873620 5.054192
20  20 4.942925 5.018719 5.127394 4.929643 5.005699 5.202787 4.869454 5.055665
21  21 4.941351 5.014454 5.125727 4.932884 5.008633 5.205170 4.870352 5.047728
22  22 4.933846 5.019311 5.130156 4.923804 5.012874 5.213346 4.874263 5.056290
23  23 4.928815 5.021575 5.139077 4.923665 5.017180 5.211699 4.876333 5.056836
24  24 4.928739 5.024419 5.140386 4.925559 5.012995 5.214019 4.880025 5.055182
25  25 4.929357 5.025198 5.134391 4.930061 5.008571 5.217005 4.885442 5.062630

原始答案

# Randomly group data into 20-row groups
set.seed(104)
data = data %>% 
  mutate(set = sample(rep(1:(500/20), each=20)))

head(data)
         A        B         C        D        E         F        G          H set
1 1.848823 6.920055 3.2283369 6.633721 6.794640 2.0288792 1.984295 2.09812642  10
2 7.023740 5.599569 0.4468325 5.198884 6.572196 0.9269249 9.700118 4.58840437  20
3 5.733263 3.426912 7.3168797 3.317611 8.301268 1.4466065 5.280740 0.09172101  19
4 1.680519 2.344975 4.9242313 6.163171 4.651894 2.2253335 1.175535 2.51299726  25
5 9.438393 4.296028 2.3563249 5.814513 1.717668 0.8130327 9.430833 0.68269106  19
6 9.434750 7.367007 1.2603451 5.952936 3.337172 5.2892300 5.139007 6.52763327   5
# Mean by set for each column
data %>% group_by(set) %>% 
  summarise_all(mean)
     set        A        B        C        D        E        F        G        H
 1     1 5.240236 6.143941 4.638874 5.367626 4.982008 4.200123 5.521844 5.083868
 2     2 5.520983 5.257147 5.209941 4.504766 4.231175 3.642897 5.578811 6.439491
 3     3 5.943011 3.556500 5.366094 4.583440 4.932206 4.725007 5.579103 5.420547
 4     4 4.729387 4.755320 5.582982 4.763171 5.217154 5.224971 4.972047 3.892672
 5     5 4.824812 4.527623 5.055745 4.556010 4.816255 4.426381 3.520427 6.398151
 6     6 4.957994 7.517130 6.727288 4.757732 4.575019 6.220071 5.219651 5.130648
 7     7 5.344701 4.650095 5.736826 5.161822 5.208502 5.645190 4.266679 4.243660
 8     8 4.003065 4.578335 5.797876 4.968013 5.130712 6.192811 4.282839 5.669198
 9     9 4.766465 4.395451 5.485031 4.577186 5.366829 5.653012 4.550389 4.367806
10    10 4.695404 5.295599 5.123817 5.358232 5.439788 5.643931 5.127332 5.089670
# ... with 15 more rows

如果数据框中的总行数不能被每个集合中所需的行数整除,则可以在创建集合时执行以下操作:

data = data %>% 
  mutate(set = sample(rep(1:ceiling(500/20), each=20))[1:n()])

在这种情况下,设置的大小会有所不同,数据行的数量不能被每组中所需的行数整除。

答案 1 :(得分:0)

对于处于类似位置的人来说,以下方法可能值得一试。

它基于 groupdata2fold() 函数中的数值平衡,它允许为单个列创建具有平衡均值的组。通过对每一列进行标准化并在数值上平衡它们的行列总和,我们可能会增加在各个列中获得平衡均值的机会。

我将这种方法与随机创建组数次并选择均值方差最小的分组进行了比较。这似乎好一点,但我不太相信这将适用于所有情况。

# Attach dplyr and groupdata2
library(dplyr)
library(groupdata2)

set.seed(1)

# Create the dataset
data <- matrix(runif(4000, min = 0, max = 10), nrow = 500, ncol = 8)
colnames(data) <- c("A", "B", "C", "D", "E", "F", "G", "H")
data <- dplyr::as_tibble(data)

# Standardize all columns and calculate row sums
data_std <- data %>% 
  dplyr::mutate_all(.funs = function(x){(x-mean(x))/sd(x)}) %>% 
  dplyr::mutate(total = rowSums(across(where(is.numeric))))

# Create groups (new column called ".folds")
# We numerically balance the "total" column 
data_std <- data_std %>% 
  groupdata2::fold(k = 25, num_col = "total")  # k = 500/20=25

# Transfer the groups to the original (non-standardized) data frame
data$group <- data_std$.folds

# Check the means
data %>% 
  dplyr::group_by(group) %>% 
  dplyr::summarise_all(.funs = mean)

> # A tibble: 25 x 9
>    group     A     B     C     D     E     F     G     H
>    <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
>  1 1      4.48  5.05  4.80  5.65  5.04  4.60  5.12  4.85
>  2 2      5.57  5.17  3.21  5.46  4.46  5.89  5.06  4.79
>  3 3      4.33  6.02  4.57  6.18  4.76  3.79  5.94  3.71
>  4 4      4.51  4.62  4.62  5.27  4.65  5.41  5.26  5.23
>  5 5      4.55  5.10  4.19  5.41  5.28  5.39  5.57  4.23
>  6 6      4.82  4.74  6.10  4.34  4.82  5.08  4.89  4.81
>  7 7      5.88  4.49  4.13  3.91  5.62  4.75  5.46  5.26
>  8 8      4.11  5.50  5.61  4.23  5.30  4.60  4.96  5.35
>  9 9      4.30  3.74  6.45  5.60  3.56  4.92  5.57  5.32
> 10 10     5.26  5.50  4.35  5.29  4.53  4.75  4.49  5.45
> # … with 15 more rows

# Check the standard deviations of the means
# Could be used to compare methods
data %>% 
  dplyr::group_by(group) %>% 
  dplyr::summarise_all(.funs = mean) %>% 
  dplyr::summarise(across(where(is.numeric), sd))

> # A tibble: 1 x 8
>       A     B     C     D     E     F     G     H
>   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
> 1 0.496 0.546 0.764 0.669 0.591 0.611 0.690 0.475

不过,最好在标准化数据上比较不同方法的均值和均值方差(或上述标准差)。在这种情况下,可以计算方差之和并将其最小化。

data_std %>% 
  dplyr::select(-total) %>% 
  dplyr::group_by(.folds) %>% 
  dplyr::summarise_all(.funs = mean) %>% 
  dplyr::summarise(across(where(is.numeric), sd)) %>% 
  sum()

> 1.643989

比较多个平衡拆分

fold() 函数允许一次创建多个唯一 分组因子(拆分)。因此,在这里,我将执行 20 次数值平衡拆分,并找到均值标准差总和最低的分组。我会进一步将其转换为函数。

create_multi_balanced_groups <- function(data, cols, k, num_tries){
  
  # Extract the variables of interest
  # We assume these are numeric but we could add a check
  data_to_balance <- data[, cols]
  
  # Standardize all columns
  # And calculate rowwise sums
  data_std <- data_to_balance %>% 
    dplyr::mutate_all(.funs = function(x){(x-mean(x))/sd(x)}) %>% 
    dplyr::mutate(total = rowSums(across(where(is.numeric))))
  
  # Create `num_tries` unique numerically balanced splits
  data_std <- data_std %>% 
    groupdata2::fold(
      k = k, 
      num_fold_cols = num_tries,
      num_col = "total"
    )
  
  # The new fold column names ".folds_1", ".folds_2", etc.
  fold_col_names <- paste0(".folds_", seq_len(num_tries))
  
  # Remove total column
  data_std <- data_std %>% 
    dplyr::select(-total)
  
  # Calculate score for each split
  # This could probably be done more efficiently without a for loop
  variance_scores <- c()
  for (fcol in fold_col_names){
    score <- data_std %>% 
      dplyr::group_by(!!as.name(fcol)) %>% 
      dplyr::summarise(across(where(is.numeric), mean)) %>% 
      dplyr::summarise(across(where(is.numeric), sd)) %>% 
      sum()
    
    variance_scores <- append(variance_scores, score)
  }
  
  # Get the fold column with the lowest score
  lowest_fcol_index <- which.min(variance_scores)
  best_fcol <- fold_col_names[[lowest_fcol_index]]
  
  # Add the best fold column / grouping factor to the original data
  data[["group"]] <- data_std[[best_fcol]]
  
  # Return the original data and the score of the best fold column
  list(data, min(variance_scores))
  
}

# Run with 20 splits
set.seed(1)
data_grouped_and_score <- create_multi_balanced_groups(
  data = data,
  cols = c("A", "B", "C", "D", "E", "F", "G", "H"),
  k = 25,
  num_tries = 20
)

# Check data
data_grouped_and_score[[1]]

> # A tibble: 500 x 9
>         A     B     C     D     E      F     G     H group
>     <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl> <fct>
>  1 5.86   6.54  0.500 2.88  5.70  9.67    2.29 3.01  2    
>  2 0.0895 4.69  5.71  0.343 8.95  7.73    5.76 9.58  1    
>  3 2.94   1.78  2.06  6.66  9.54  0.600   4.26 0.771 16   
>  4 2.77   1.52  0.723 8.11  8.95  1.37    6.32 6.24  7    
>  5 8.14   2.49  0.467 8.51  0.889 6.28    4.47 8.63  13   
>  6 2.60   8.23  9.17  5.14  2.85  8.54    8.94 0.619 23   
>  7 7.24   0.260 6.64  8.35  8.59  0.0862  1.73 8.10  5    
>  8 9.06   1.11  6.01  5.35  2.01  9.37    7.47 1.01  1    
>  9 9.49   5.48  3.64  1.94  3.24  2.49    3.63 5.52  7    
> 10 0.731  0.230 5.29  8.43  5.40  8.50    3.46 1.23  10   
> # … with 490 more rows

# Check score
data_grouped_and_score[[2]]

> 1.552656

通过注释掉 num_col = "total" 行,我们可以在没有数字平衡的情况下运行它。对我来说,这给出了 1.615257 的分数。

免责声明:我是 groupdata2 包的作者。 fold() 函数还可以平衡分类列 (cat_col) 并将具有相同 ID 的所有数据点保留在同一折叠中 (id_col)(例如避免交叉验证中的泄漏) .还有一个非常相似的 partition() 函数。