查找每行最大索引的最快方法

时间:2016-01-27 21:02:50

标签: r matrix range max min

我试图找到一种最佳方法来查找每行中最大值的索引。问题是我无法找到一种真正有效的方法。 一个例子:

Dummy <- matrix(runif(500000000,0,3), ncol = 10000)
> system.time(max.col(Dummy, "first"))
   user  system elapsed 
  5.532   0.075   5.599 
> system.time(apply(Dummy,1,which.max))
   user  system elapsed 
 14.638   0.210  14.828 
> system.time(rowRanges(Dummy))
   user  system elapsed 
  2.083   0.029   2.109 

我的主要问题是,与使用rowRanges函数计算最大值和最小值相比,为什么计算最大值的索引的速度超过2倍。有没有办法如何提高计算每行最大索引的性能?

5 个答案:

答案 0 :(得分:6)

扩展krlmlr的答案,一些基准:

在数据集上:

maxCol_R

maxCol_col是一个R-by-row循环,maxCol_row是一个C-by-row循环,microbenchmark::microbenchmark(max.col(Dummy, "first"), maxCol_R(Dummy), maxCol_col(Dummy), maxCol_row(Dummy), times = 30) #Unit: milliseconds # expr min lq median uq max neval # max.col(Dummy, "first") 1209.28408 1245.24872 1268.34146 1291.26612 1504.0072 30 # maxCol_R(Dummy) 1060.99994 1084.80260 1099.41400 1154.11213 1436.2136 30 # maxCol_col(Dummy) 86.52765 87.22713 89.00142 93.29838 122.2456 30 # maxCol_row(Dummy) 577.51613 583.96600 598.76010 616.88250 671.9191 30 all.equal(max.col(Dummy, "first"), maxCol_R(Dummy)) #[1] TRUE all.equal(max.col(Dummy, "first"), maxCol_col(Dummy)) #[1] TRUE all.equal(max.col(Dummy, "first"), maxCol_row(Dummy)) #[1] TRUE 是一个C-by-row循环。

maxCol_R = function(x)
{
    ans = rep_len(1L, nrow(x))
    mx = x[, 1L]

    for(j in 2:ncol(x)) {
        tmp = x[, j]
        wh = which(tmp > mx)

        ans[wh] = j
        mx[wh] = tmp[wh]
    }

    ans
} 

maxCol_col = inline::cfunction(sig = c(x = "matrix"), body = '
    int nr = INTEGER(getAttrib(x, R_DimSymbol))[0], nc = INTEGER(getAttrib(x, R_DimSymbol))[1]; 
    double *px = REAL(x), *buf = (double *) R_alloc(nr, sizeof(double));
    for(int i = 0; i < nr; i++) buf[i] = R_NegInf;

    SEXP ans = PROTECT(allocVector(INTSXP, nr));
    int *pans = INTEGER(ans);

    for(int j = 0; j < nc; j++) {
        for(int i = 0; i < nr; i++) {
            if(px[i + j*nr] > buf[i]) {
                buf[i] = px[i + j*nr];
                pans[i] = j + 1;
            }
        }
    }

    UNPROTECT(1);
    return(ans);
', language = "C")

maxCol_row = inline::cfunction(sig = c(x = "matrix"), body = '
    int nr = INTEGER(getAttrib(x, R_DimSymbol))[0], nc = INTEGER(getAttrib(x, R_DimSymbol))[1]; 
    double *px = REAL(x), *buf = (double *) R_alloc(nr, sizeof(double));
    for(int i = 0; i < nr; i++) buf[i] = R_NegInf;

    SEXP ans = PROTECT(allocVector(INTSXP, nr));
    int *pans = INTEGER(ans);

    for(int i = 0; i < nr; i++) {
        for(int j = 0; j < nc; j++) {
            if(px[i + j*nr] > buf[i]) {
                buf[i] = px[i + j*nr];
                pans[i] = j + 1;
            }
        }
    }

    UNPROTECT(1);
    return(ans);
', language = "C")

功能:

rangeCol = inline::cfunction(sig = c(x = "matrix"), body = '
    int nr = INTEGER(getAttrib(x, R_DimSymbol))[0], nc = INTEGER(getAttrib(x, R_DimSymbol))[1]; 
    double *px = REAL(x), 
           *maxbuf = (double *) R_alloc(nr, sizeof(double)),
           *minbuf = (double *) R_alloc(nr, sizeof(double));
    memcpy(maxbuf, &(px[0 + 0*nr]), nr * sizeof(double));
    memcpy(minbuf, &(px[0 + 0*nr]), nr * sizeof(double));

    SEXP ans = PROTECT(allocMatrix(INTSXP, nr, 2));
    int *pans = INTEGER(ans); 
    for(int i = 0; i < LENGTH(ans); i++) pans[i] = 1;

    for(int j = 1; j < nc; j++) {
        for(int i = 0; i < nr; i++) {
            if(px[i + j*nr] > maxbuf[i]) {
                maxbuf[i] = px[i + j*nr];
                pans[i] = j + 1;
            }
            if(px[i + j*nr] < minbuf[i]) {
                minbuf[i] = px[i + j*nr];
                pans[i + nr] = j + 1;
            }
        }
    }

    UNPROTECT(1);
    return(ans);
', language = "C")

set.seed(007); m = matrix(sample(24) + 0, 6, 4)
m
#     [,1] [,2] [,3] [,4]
#[1,]   24    7   23    6
#[2,]   10   17   21   11
#[3,]    3   22   20   14
#[4,]    2   18    1   15
#[5,]    5   19   12    8
#[6,]   16    4    9   13
rangeCol(m)
#     [,1] [,2]
#[1,]    1    4
#[2,]    3    1
#[3,]    2    1
#[4,]    2    3
#[5,]    2    1
#[6,]    1    2       

编辑 16年6月10日

稍微改变一下,找到max和min的指数:

{{1}}

答案 1 :(得分:5)

这是一个非常基本的Rcpp实现:

#include <Rcpp.h>

// [[Rcpp::export]]
Rcpp::NumericVector MaxCol(Rcpp::NumericMatrix m) {
    R_xlen_t nr = m.nrow(), nc = m.ncol(), i = 0;
    Rcpp::NumericVector result(nr);

    for ( ; i < nr; i++) {
        double current = m(i, 0);
        R_xlen_t idx = 0, j = 1;
        for ( ; j < nc; j++) {
            if (m(i, j) > current) {
                current = m(i, j);
                idx = j;
            }
        }
        result[i] = idx + 1;
    }
    return result;
}

/*** R

microbenchmark::microbenchmark(
    "Rcpp" = MaxCol(Dummy), 
    "R" = max.col(Dummy, "first"),
    times = 200L
)
#Unit: milliseconds
# expr      min       lq     mean   median       uq      max neval
# Rcpp 221.7777 224.7442 242.0089 229.6407 239.6339 455.9549   200
# R    513.4391 524.7585 562.7465 539.4829 562.3732 944.7587   200

*/

由于我的笔记本电脑没有足够的内存,我不得不将您的样本数据缩小一个数量级,但结果应该转换为原始样本数据:

Dummy <- matrix(runif(50000000,0,3), ncol = 10000)
all.equal(MaxCol(Dummy), max.col(Dummy, "first"))
#[1] TRUE

可以稍微更改一下,以返回每行minmax的索引:

// [[Rcpp::export]]
Rcpp::NumericMatrix MinMaxCol(Rcpp::NumericMatrix m) {
    R_xlen_t nr = m.nrow(), nc = m.ncol(), i = 0;
    Rcpp::NumericMatrix result(nr, 2);

    for ( ; i < nr; i++) {
        double cmin = m(i, 0), cmax = m(i, 0);
        R_xlen_t min_idx = 0, max_idx = 0, j = 1;
        for ( ; j < nc; j++) {
            if (m(i, j) > cmax) {
                cmax = m(i, j);
                max_idx = j;
            }
            if (m(i, j) < cmin) {
                cmin = m(i, j);
                min_idx = j;
            }
        }
        result(i, 0) = min_idx + 1;
        result(i, 1) = max_idx + 1;
    }
    return result;
}

答案 2 :(得分:3)

R将矩阵存储在column-major order中。因此,迭代通常会更快,因为一列的值在内存中彼此接近,并且将一次遍历缓存层次结构:

align-items: center;

这应该接近最快的Rcpp解决方案能够获得的最短时间。使用nav的任何解决方案都必须复制每个列/行,这可以在使用Rcpp时保存。您可以决定潜在的加速速度是否值得您努力。

答案 3 :(得分:2)

通常,在R中执行操作的最快方法是调用C,C ++或FORTRAN。

似乎matrixStats::rowRangesimplemented in C,这解释了为什么它是最快的。

如果你想进一步提高性能,可能会有一点点速度来修改rowRanges.c代码以忽略最小值并获得最大值,但我认为增益会非常小

答案 4 :(得分:1)

尝试使用STL算法和RcppArmadillo。

microbenchmark::microbenchmark(MaxColArmadillo(Dummy), #Using RcppArmadillo
                               MaxColAlgorithm(Dummy), #Using STL algorithm max_element
                               maxCol_col(Dummy), #Column processing
                               maxCol_row(Dummy)) #Row processing

Unit: milliseconds
                   expr       min        lq     mean    median       uq      max neval
 MaxColArmadillo(Dummy) 227.95864 235.01426 261.4913 250.17897 276.7593 399.6183   100
 MaxColAlgorithm(Dummy) 292.77041 345.84008 392.1704 390.66578 433.8009 552.2349   100
      maxCol_col(Dummy)  40.64343  42.41487  53.7250  48.10126  61.3781 128.4968   100
      maxCol_row(Dummy) 146.96077 158.84512 173.0941 169.20323 178.7959 272.6261   100

STL实施

#include <Rcpp.h>

// [[Rcpp::export]]

// Argument is a matrix ansd returns a 
// vector of max of each of the rows of the matrix

Rcpp::NumericVector MaxColAlgorithm(Rcpp::NumericMatrix m) {

//int numOfRows = m.rows();

//Create vector with 0 of size numOfRows
Rcpp::NumericVector total(m.rows());

  for(int i = 0; i < m.rows(); ++i)
  {
    //Create vector of the rows of matrix
    Rcpp::NumericVector rVec = m.row(i);

    //Apply STL max of elemsnts on the vector and store in a vector
    total(i) = *std::max_element(rVec.begin(), rVec.end());
  }

  return total;

}

RcppArmadillo实施

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;

// [[Rcpp::export]]
arma::mat MaxColArmadillo(arma::mat x) 
{
  //RcppArmadillo max function where dim = 1 means max of each row
  // of the matrix
  return(max(x,1)); 
}