我在下面写了一个简单的函数:
mcs <- function(v) { ifelse(sum((diff(sort(v)) > 6) > 0), NA, sd(v)) }
应该采用矢量,对其进行排序,然后检查每个连续差异中是否存在大于6的差异。如果差异大于6,则返回NA;如果不存在,则返回标准偏差。
我想在数据表的所有行中应用此函数(仅选择某些列),然后将每行的返回值作为新列条目追加到数据表中。
例如,给定一个像这样的数据表
> dat <- data.table(A=c(1,2,3,4,5), B=c(2,3,4,10,6), C=c(3,4,10,6,8),
D=c(3,3,3,3,3))
> dat
A B C D
1: 1 2 3 3
2: 2 3 4 3
3: 3 4 10 3
4: 4 10 6 3
5: 5 6 8 3
我想生成下面的输出。 (我在每行的第2,3和4列应用了函数。)
> dat
A B C D sd
1: 1 2 3 3 0.5773503
2: 2 3 4 3 0.5773503
3: 3 4 10 3 3.7859389
4: 4 10 6 3 3.5118846
5: 5 6 8 3 2.5166115
我了解到可以使用以下方法通过行操作完成数据表:
> dat[, sd:=apply(.SD, 1, mcs), .SDcols=(c(2,3,4))]
除了太慢之外,这种方法有效。我必须在几个大型数据表上执行此操作,并编写了一个脚本来执行此操作。但是,它仅适用于较小的数据表。对于有大约300,000行的表,它会在几秒钟内完成,但是当我尝试使用一个有大约8亿行的表时,我的程序还没有完成。我试过等了两个小时,我认为R打破了什么,因为控制台只是冻结了。我已经尝试过几次运行脚本,它总是正确完成前几个较小的表(我让程序将表写入文件进行检查)但是当它到达大数据表时,它永远不会完成。我在计算群集上运行它,所以我绝对不认为这是硬件限制。可能是糟糕的代码。
我认为瓶颈是应用中的循环,但我不知道如何让它更快。我对R很新,所以我不确定如何优化我的代码。我在互联网上看过很多关于矢量化的帖子,我想也许如果我能同时将我的功能应用到每一行它会更快,但我不知道该怎么做。请帮忙。
修改
抱歉,我在复制mcs
功能时犯了一个错误。我已经更新了它。
编辑2
对于那些感兴趣的人,我最终将表分成两半并分别对每一半进行操作,这对我有用。
答案 0 :(得分:4)
如果你真的需要速度,一如既往最好使用Rcpp转向C ++,这为我们提供了超过100倍的解决方案。
我确实做了一些不同的示例数据来测试这个有1000行而不是5行的数据:
set.seed(123)
dat <- data.table(A = rnorm(1e3, sd=4), B = rnorm(1e3, sd=4), C = rnorm(1e3, sd=4),
D = rnorm(1e3, sd=4), E = rnorm(1e3, sd=4))
我使用以下C ++代码来执行与函数相同的操作,但现在循环使用C ++而不是R through apply完成,这节省了大量时间:
#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::export]]
NumericVector mcs2(DataFrame x) {
int n = x.nrows();
int m = x.size();
NumericMatrix mat(n, m);
for ( int j = 0; j < m; ++j ) {
mat(_, j) = NumericVector(x[j]);
}
NumericVector result(n);
for ( int i = 0; i < n; ++i ) {
NumericVector tmp = mat(i, _);
std::sort(tmp.begin(), tmp.end());
bool do_sd = true;
for ( int j = 1; j < m; ++j ) {
if ( tmp[j] - tmp[j-1] > 6.0 ) {
result[i] = NA_REAL;
do_sd = false;
break;
}
}
if ( do_sd ) {
result[i] = sd(tmp);
}
do_sd = true;
}
return result;
}
我们可以确保它返回相同的值:
all.equal(apply(dat[, 2:4], 1, mcs1), mcs2(dat[,2:4]))
[1] TRUE
现在让我们进行基准测试:
benchmark(mcs1 = dat[, sd:=apply(.SD, 1, mcs1), .SDcols=(c(2,3,4))],
mcs2 = dat[, sd:=mcs2(.SD), .SDcols=(c(2,3,4))],
order = 'relative',
columns = c('test', 'elapsed', 'relative', 'user.self'))
test elapsed relative user.self
2 mcs2 0.19 1.000 0.183
1 mcs1 21.34 112.316 20.044
作为通过Rcpp使用C ++代码的介绍,我建议Hadley Wickham的高级R this chapter。如果你打算用Rcpp做进一步的事情,我强烈建议你也阅读官方文档和小插曲但是Wickham的书可能更适合初学者作为起点。为了您的目的,您只需要启动并运行Rcpp,以便您可以编译上面的代码。
要使此代码适合您,如果您还没有Rcpp软件包,则需要它。您可以通过运行
来获取包install.packages(Rcpp)
来自R.注意你还需要一个编译器;如果您使用的是基于Debian的Linux系统(如Ubuntu),则可以运行
sudo apt install r-base-dev
来自终端。如果您使用的是Mac或Windows,请查看here以获取有关获取此设置的一些说明,或查看上面链接的Wickham章节。
安装Rcpp后,将上面的C ++代码保存到文件中。假设我们的示例文件名为“SOanswer.cpp”。然后,您可以通过在R脚本中添加以下两行来使其mcs2()
函数可用:
library(Rcpp)
sourceCpp("SOanswer.cpp") # assuming the file is in your working directory
就是这样!现在你的R脚本可以调用mcs2()
并且运行得更快。如果你想了解更多关于Rcpp的信息,除了上面的Wickham章节,我还会查看参考手册以及来自RStudio的可用的小插曲here,this page(其中包含大量的链接,一些这些链接到这里),你也可以找到一些非常有用的东西,看看Rcpp gallery。