使用Rcpp复制dplyr :: group_by的索引功能

时间:2018-04-15 15:52:11

标签: c++ r dplyr rcpp

作为练习,我正在尝试使用Rcpp和C ++来获取分组索引,就像dplyr::group_by提供的那样。这些是与数据中每个组对应的行号(从0开始)。

以下是指数的示例。

x <- sample(1:3, 10, TRUE)
x
# [1] 3 3 3 1 3 1 3 2 3 2

df <- data.frame(x)
attr(dplyr::group_by(df, x), "indices")
#[[1]]
#[1] 3 5
#
#[[2]]
#[1] 7 9
#
#[[3]]
#[1] 0 1 2 4 6 8

到目前为止,使用标准库std::unordered_multimap,我想出了以下内容:

// [[Rcpp::plugins(cpp11)]]

#include <Rcpp.h>
using namespace Rcpp;

typedef std::vector<int> rowvec;

// [[Rcpp::export]]
std::vector<rowvec> rowlist(std::vector<int> x)
{
    std::unordered_multimap<int, int> rowmap;
    for (size_t i = 0; i < x.size(); i++)
    {
        rowmap.insert({ x[i], i });
    }

    std::vector<rowvec> rowlst;
    for (size_t i = 0; i < rowmap.bucket_count(); i++)
    {
        if (rowmap.begin(i) != rowmap.end(i))
        {
            rowvec v(rowmap.count(i));
            int b = 0;
            for (auto it = rowmap.begin(i); it != rowmap.end(i); ++it, b++)
            {
               v[b] = it->second;
            }
            rowlst.push_back(v);
        }
    }
    return rowlst;
}

在单个变量上运行此操作会导致

rowlist(x)
#[[1]]
#[1] 5 3
#
#[[2]]
#[1] 9 7
#
#[[3]]
#[1] 8 6 4 2 1 0

除了反向排序,这看起来不错。但是,我无法弄清楚如何扩展它来处理:

  • 不同的数据类型;该类型目前已硬编码到函数
  • 多个分组变量

std::unordered_multimapgroup_by相比也相当慢,但稍后我会处理。)任何帮助都会受到赞赏。

1 个答案:

答案 0 :(得分:2)

我已经对这个问题进行了一段时间的研究,我的结论是,至少可以说这很难。为了复制dplyr::group_by的魔力,你将不得不编写几个类,并设置一个非常灵活的散列函数来处理各种数据类型和不同数量的列。我已经搜索了dplyr源代码,看起来如果你按照ChunkMapIndex的创建,你会得到更好的理解。

说到数据类型,我甚至不确定使用std::unordered_multimap可以获得您想要的内容,因为它是不明智的,difficult可以使用double/float数据类型作为你的钥匙。

考虑到所提到的所有挑战,下面的代码将使用整数类型生成与attr(dplyr::group_by(df, x), "indices")相同的输出。我已经将其设置为让您开始考虑如何处理不同的数据类型。它使用带有辅助函数的模板化方法,因为它是处理不同数据类型的简单有效的解决方案。辅助函数与Dirk提供的链接中的函数非常相似。

// [[Rcpp::plugins(cpp11)]]

#include <Rcpp.h>
#include <string>
using namespace Rcpp;

typedef std::vector<int> rowvec;
typedef std::vector<rowvec> rowvec2d;

template <typename T>
rowvec2d rowlist(std::vector<T> x) {

    std::unordered_multimap<T, int> rowmap;
    for (int i = 0; i < x.size(); i++)
        rowmap.insert({ x[i], i });

    rowvec2d rowlst;

    for (int i = 0; i < rowmap.bucket_count(); i++) {
        if (rowmap.begin(i) != rowmap.end(i)) {
            rowvec v(rowmap.count(i));
            int b = 0;
            for (auto it = rowmap.begin(i); it != rowmap.end(i); ++it, b++)
                v[b] = it->second;

            rowlst.push_back(v);
        }
    }

    return rowlst;
}

template <typename T>
rowvec2d tempList(rowvec2d myList, std::vector<T> v) {

    rowvec2d vecOut;

    if (myList.size() > 0) {
        for (std::size_t i = 0; i < myList.size(); i++) {
            std::vector<T> vecPass(myList[i].size());
            for (std::size_t j = 0; j < myList[i].size(); j++)
                vecPass[j] = v[myList[i][j]];

            rowvec2d vecTemp = rowlist(vecPass);
            for (std::size_t j = 0; j < vecTemp.size(); j++) {
                rowvec myIndex(vecTemp[j].size());
                for (std::size_t k = 0; k < vecTemp[j].size(); k++)
                    myIndex[k] = myList[i][vecTemp[j][k]];

                vecOut.push_back(myIndex);
            }
        }
    } else {
        vecOut = rowlist(v);
    }

    return vecOut;
}

// [[Rcpp::export]]
rowvec2d rowlistMaster(DataFrame myDF) {

    DataFrame::iterator itDF;
    rowvec2d result;

    for (itDF = myDF.begin(); itDF != myDF.end(); itDF++) {
        switch(TYPEOF(*itDF)) {
            case INTSXP: {
                result = tempList(result, as<std::vector<int> >(*itDF));
                break;
            }
            default: {
                stop("v must be of type integer");
            }
        }
    }

    return result;
}

它适用于多个分组变量,但速度并不快。

set.seed(101)
x <- sample(1:5, 10^4, TRUE)
y <- sample(1:5, 10^4, TRUE)
w <- sample(1:5, 10^4, TRUE)
z <- sample(1:5, 10^4, TRUE)
df <- data.frame(x,y,w,z)

identical(attr(dplyr::group_by(df, x, y, w, z), "indices"), rowlistMaster(df))
[1] TRUE

library(microbenchmark)
microbenchmark(dplyr = attr(dplyr::group_by(df, x, y, w, z), "indices"),
               challenge = rowlistMaster(df))
Unit: milliseconds
     expr       min        lq       mean     median         uq        max neval
    dplyr  2.693624  2.900009   3.324274   3.192952   3.535927   6.827423   100
challenge 53.905133 70.091335 123.131484 141.414806 149.923166 190.010468   100