Speed up a conditional R loop with C/C++ or vectorization

Time: 2015-07-22 01:39:04

Tags: r data.table dplyr rcpp

I need to create a matrix of covariate values based on other variables in a dataset. Here is a self-contained example of my current solution:

library(dplyr)
library(survival)
library(microbenchmark)
data(heart, package = "survival")
data <- heart

# Number of unique subjects
n.sub <- data %>% group_by(id) %>% n_groups()
# Unique failure times
fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
# Number of unique failure times
n.fail.time <- length(fail.time)

# Pre-fill matrix. Will be filled with covariate values.
mat <- matrix(NA_real_, nrow = n.sub, ncol = n.fail.time)
# Run loop
for(i in 1:n.sub) { # Number of subjects
  data.subject <- data[data$id == i, ] # subsetting here provides nice speed-up
  for(j in 1:n.fail.time) { # Number of failure times.
    value <- subset(data.subject, (start < fail.time[j]) & (stop >= fail.time[j]), select = transplant, drop = TRUE)
    if(length(value) == 0) { # An early event or censor will return empty value. Assign to zero.
      mat[i, j] <- 0
    }
    else {
      mat[i, j] <- value # True value
    }
  }
}

This is far too slow for datasets with several thousand observations. I don't know how best to vectorize this in R code, and I don't know enough C/C++ to use Rcpp. How can I speed up this example with one of those options (or some other one)?

It looks like the src/aalen.c file in the timereg package may contain a solution to a problem similar to mine. See the code around the line if ((start[c]<time) && (stop[c]>=time)). Although this may just be my ignorance of C programming.
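For what it's worth, here is a minimal Rcpp sketch (my own illustration, not from this thread) of how that same condition could drive a single pass over the data rows in C++. It assumes, as in the heart data, that ids run from 1 to n.sub and that at most one interval per subject covers each failure time:

library(Rcpp)

cppFunction('
NumericMatrix fillMat(IntegerVector id, NumericVector start,
                      NumericVector stop, NumericVector covariate,
                      NumericVector failTime, int nSub) {
  int n = id.size(), m = failTime.size();
  NumericMatrix mat(nSub, m);             // zero-initialised, so empty cells stay 0
  for (int r = 0; r < n; r++) {           // one pass over the data rows
    for (int j = 0; j < m; j++) {
      if (start[r] < failTime[j] && stop[r] >= failTime[j])
        mat(id[r] - 1, j) = covariate[r]; // the aalen.c-style interval test
    }
  }
  return mat;
}')

# hypothetical usage: transplant is a factor, so pass its integer codes
# mat3 <- fillMat(as.integer(data$id), data$start, data$stop,
#                 as.numeric(data$transplant), fail.time, n.sub)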

2 Answers:

Answer 0 (score: 4)

I faced a similar choice and moved to C++ to speed up a different problem, but I ended up turning instead to R packages that already implement efficient operations in C++ and using those. Here, the package you want is data.table.

If you are new to R this may be hard to follow, but the data.table package is well documented in the vignettes here. To understand what is going on below, you might step through my code on the simplified dataset I tested (see the bottom of this answer) and watch the objects change values. The keys to the speed-up are using data.table's fast assignment methods and performing only vectorized operations.
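As a toy illustration of reference assignment (a made-up snippet, not part of the solution below):

library(data.table)

dt <- data.table(x = 1:5)
dt[, y := x * 2L]     # adds column y by reference, without copying dt
dt[x > 3, y := 0L]    # conditional update, also in place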

My solution is below. Note that I was not sure whether you needed the 0, 1, 2 values, but I am happy to change the code to produce 0, 1 if that was your intent.

require(data.table)

dataDT <- data.table(data[, c("id", "start", "stop", "transplant")])
# add a serial number for each id
dataDT[, idObs := 1:length(start), by = id ]
# needed because transplant is a factor in the heart dataset
dataDT[, transplant := as.integer(transplant)]

# create a "long" format data.table of subjects, observation number, and start/stop times
matDT <- data.table(subject = rep(1:n.sub, each = n.fail.time * max(dataDT$idObs)),
                    idObs = rep(1:max(dataDT$idObs), max(dataDT$idObs), n.sub * max(dataDT$idObs)),
                    fail.time = rep(fail.time, each = max(dataDT$idObs)))

# merge in start and stop times
setkey(matDT, subject, idObs)
setkey(dataDT, id, idObs)
matDT <- dataDT[matDT]

# eliminate missings (for which no 2nd observation took place)
matDT <- matDT[!is.na(transplant)]

# this replicates the "value" assignment in the loop
matDT[, value := transplant * ((start < fail.time) & (stop >= fail.time))]

# sum on the ids by fail time
matDT2 <- matDT[, list(matVal = sum(value)), by = list(id, fail.time)]

# convert to a matrix
mat2 <- matrix(matDT2$matVal, ncol = ncol(mat), byrow = TRUE, dimnames = list(1:n.sub, fail.time))
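As an aside on the 0, 1 coding mentioned above: transplant's integer codes in the heart data are 1 and 2, so one possible one-line tweak (my own untested suggestion, not part of the answer) would be to recode at the as.integer() step:

# hypothetical variant of the as.integer() line above, mapping codes 1/2 to 0/1
dataDT[, transplant := as.integer(transplant) - 1L]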

According to microbenchmark(), this is many times faster than your code; the first method below is the code from the question:

Unit: milliseconds
        min         lq       mean     median         uq       max neval
 310.503535 339.364159 396.287178 354.292829 406.937216 762.28838   100
   7.113083   7.420517   9.436973   7.788479   9.426443  32.50355   100

To show the output, I tested this on the first six rows of the data object. This makes a good example because the third and fourth patients (id = 3, 4) each have two observations, one before and one after transplant.

data <- heart[1:6, ]

Then I added row and column labels to your mat object:

colnames(mat) <- fail.time
rownames(mat) <- 1:n.sub
mat
##   6 16 39 50
## 1 1  1  1  1
## 2 1  0  0  0
## 3 2  2  0  0
## 4 1  1  2  0

You can see here that the new mat2 is identical:

mat2
##   6 16 39 50
## 1 1  1  1  1
## 2 1  0  0  0
## 3 2  2  0  0
## 4 1  1  2  0
all.equal(mat, mat2)
## [1] TRUE

Answer 1 (score: 3)

Here is a dplyr version of @KenBenoit's solution (see the dplyr.matrix function). Below is code that tests all three approaches.

library(dplyr)
library(data.table)
library(survival)
library(microbenchmark)
data(heart, package = "survival")
data <- heart

old.matrix <- function(data) {
  # Number of unique subjects
  n.subjects <- data %>% group_by(id) %>% n_groups()
  # Unique failure times
  fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
  # Number of unique failure times
  n.fail.time <- length(fail.time)

  # Pre-fill matrix. Will be filled with covariate values.
  mat <- matrix(NA_real_, nrow = n.subjects, ncol = n.fail.time)
  # Run loop
  for(i in 1:n.subjects) { # Number of subjects
    data.subject <- data[data$id == i, ] # subsetting here provides nice speed-up
    for(j in 1:n.fail.time) { # Number of failure times.
      value <- subset(data.subject, (start < fail.time[j]) & (stop >= fail.time[j]), select = transplant, drop = TRUE)
      if(length(value) == 0) { # An early event or censor will return empty value. Assign to zero.
        mat[i, j] <- 0
      }
      else {
        mat[i, j] <- value # True value
      }
    }
  }
  mat
}

dplyr.matrix <- function(data) {
  # Number of unique subjects
  n.subjects <- data %>% group_by(id) %>% n_groups()
  # Unique failure times
  fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
  # Number of unique failure times
  n.fail.time <- length(fail.time)

  # add a serial number for each id
  data <- data %>% group_by(id) %>% mutate(id.serial = 1:length(start))
  # needed because transplant is a factor in the heart dataset
  data$transplant <- as.integer(data$transplant)
  # create a "long" format data.frame of subjects, observation number, and start/stop times
  data.long <- data.frame(
    id = rep(1:n.subjects, each = n.fail.time * max(data$id.serial)),
    id.serial = rep(1:max(data$id.serial), max(data$id.serial), n.subjects * max(data$id.serial)),
    fail.time = rep(fail.time, each = max(data$id.serial))
  )
  # merge in start and stop times
  data.merge <- left_join(data.long, data[, c("start", "stop", "transplant", "id", "id.serial")], by = c("id", "id.serial"))
  # eliminate missings (for which no 2nd observation took place)
  data.merge <- na.omit(data.merge)
  # this replicates the "value" assignment in the loop
  data.merge <- data.merge %>% mutate(value = transplant * ((start < fail.time) & (stop >= fail.time)))
  # sum on the ids by fail time
  data.merge <- data.merge %>% group_by(id, fail.time) %>% summarise(value = sum(value))
  # convert to a matrix
  data.matrix <- matrix(data.merge$value, ncol = n.fail.time, byrow = TRUE, dimnames = list(1:n.subjects, fail.time))
  data.matrix
}

data.table.matrix <- function(data) {
  # Number of unique subjects
  n.subjects <- data %>% group_by(id) %>% n_groups()
  # Unique failure times
  fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
  # Number of unique failure times
  n.fail.time <- length(fail.time)

  dataDT <- data.table(data[, c("id", "start", "stop", "transplant")])
  # add a serial number for each id
  dataDT[, idObs := 1:length(start), by = id ]
  # needed because transplant is a factor in the heart dataset
  dataDT[, transplant := as.integer(transplant)]
  # create a "long" format data.table of subjects, observation number, and start/stop times
  matDT <- data.table(subject = rep(1:n.subjects, each = n.fail.time * max(dataDT$idObs)),
                      idObs = rep(1:max(dataDT$idObs), max(dataDT$idObs), n.subjects * max(dataDT$idObs)),
                      fail.time = rep(fail.time, each = max(dataDT$idObs)))
  # merge in start and stop times
  setkey(matDT, subject, idObs)
  setkey(dataDT, id, idObs)
  matDT <- dataDT[matDT]
  # eliminate missings (for which no 2nd observation took place)
  matDT <- matDT[!is.na(transplant)]
  # this replicates the "value" assignment in the loop
  matDT[, value := transplant * ((start < fail.time) & (stop >= fail.time))]
  # sum on the ids by fail time
  matDT2 <- matDT[, list(matVal = sum(value)), by = list(id, fail.time)]
  # convert to a matrix
  mat2 <- matrix(matDT2$matVal, ncol = n.fail.time, byrow = TRUE, dimnames = list(1:n.subjects, fail.time))
  mat2
}

all(dplyr.matrix(data) == old.matrix(data))
all(dplyr.matrix(data) == data.table.matrix(data))

microbenchmark(
  old.matrix(data),
  dplyr.matrix(data),
  data.table.matrix(data),
  times = 50
)

Output of microbenchmark:

Unit: milliseconds
                    expr        min         lq      mean    median        uq       max neval cld
        old.matrix(data) 325.949687 328.102482 333.20923 329.39368 331.28305 373.44774    50   c
      dplyr.matrix(data)  17.586146  18.317833  20.04662  18.95724  19.62431  60.15858    50  b 
 data.table.matrix(data)   9.464045   9.892281  10.72819  10.29394  11.44812  12.67738    50 a  

The results above correspond to a dataset of roughly 100 observations. When I tested on a dataset of roughly 1000 observations, data.table pulled further ahead:

Unit: milliseconds
                    expr        min         lq       mean     median        uq        max neval cld
        old.matrix(data) 13095.7836 13114.1858 13162.5019 13134.0735 13150.217 13318.2496     5   c
      dplyr.matrix(data)  1067.1942  1075.5291  1149.0789  1166.8951  1197.998  1237.7787     5  b 
 data.table.matrix(data)   104.5133   155.2074   159.6794   159.6364   166.764   212.2758     5 a  

data.table is now the clear winner.
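For reference, here is one way (my own sketch; the answer does not say how its larger test data were generated) to build a test set of roughly 1000 observations by stacking copies of heart with shifted ids:

# hypothetical larger test set: 6 copies of heart (172 rows each), ids kept
# unique and contiguous across copies so all three functions still work
data.big <- do.call(rbind, lapply(0:5, function(k) {
  d <- heart
  d$id <- d$id + k * max(heart$id)
  d
}))

microbenchmark(old.matrix(data.big), dplyr.matrix(data.big),
               data.table.matrix(data.big), times = 5)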