查找大矩阵行之间的最小距离:分配限制错误

时间:2016-10-31 10:11:06

标签: r distance allocation euclidean-distance

我想计算一个大矩阵的所有行之间的距离。对于每一行,我需要找到另一行,其距离最短。最终输出应该是一个列表,其中包含距离最短的行的ID(请参阅下面示例中的low_dis_ids)。

我能够找到小样本的解决方案(例如下面的例子)。但是,我无法使用更大的样本量执行这些步骤,因为距离矩阵变得很大。有没有办法省略这么大的矩阵?我只需要带有ID的列表(如low_dis_ids)。

可重复的例子:

isUserAllowedToLeave()

2 个答案:

答案 0 :(得分:1)

您绝对可以避免使用dist创建的大型矩阵。一种这样的解决方案是一次计算一行的距离,我们可以编写给定整个数据帧的函数,并且行id找到哪个行对应于最小距离。例如:

f = function(rowid, whole){
  d = colSums((whole[rowid,] - t(whole))^2) # calculate distance
  d[rowid] = Inf # replace the zero
  which.min(d)
}

colSums功能相当优化,因此速度相对较快。我怀疑which.min也是一种稍微快一些,可能更简洁的方法来循环通过距离矢量。

为了制作一个解决方案然后应用于任何这样的数据框我写了另一个短函数,它将这个应用于每一行并给你一个行ID的向量

mindists = function(dat) do.call(c,lapply(1:nrow(dat),f,whole = as.matrix(dat)))

如果您想要列表而不是矢量,只需省略do.call函数即可。我这样做是为了更容易检查输出是否符合您的预期。

all(do.call(c,low_dis_ids) == mindists(data_100))
[1] TRUE

这也适用于我的笔记本电脑上的更大范例。由于您正在nrow(data)进行f调用,因此速度并不快,但它确实避免了创建一个大对象。可能有更好的解决方案,但这是第一个出现在脑海中的解决方案。希望有所帮助。

编辑:

编辑,因为Roland还有一个额外的C ++答案 我对较小的数据集进行了快速基准测试。在这种情况下,C ++的答案肯定更快。 如果你纯粹是一个R程序员(不需要学习C ++和RCpp),我认为更容易理解这个答案的一些额外销售推销。使用lapply的并行替换,R版本很容易并行化。我要注意的是,这并不是要取消Rolands的回答,我个人喜欢Rcpp,只是为这个问题的未来读者提供额外的信息。

答案 1 :(得分:1)

使用Rcpp,因为基础R解决方案太慢了:

library(Rcpp)
library(inline)
cppFunction(
"  IntegerVector findLowestDist(const NumericMatrix X) {
    const int n = X.nrow();
    const int m = X.ncol();
    IntegerVector minind(n);
    NumericVector minsqdist(n);
    double d;
    for (int i = 0; i < n; ++i) {
      if (i == 0) {
        d = 0;
        for (int k = 0; k < m; ++k) {
          d += pow(X(i, k) - X(1, k), 2.0);

        }
        minsqdist(i) = d;
        minind(i) = 1;
      } else {
        d = 0;
        for (int k = 0; k < m; ++k) {
          d += pow(X(i, k) - X(0, k), 2.0);

        }
        minsqdist(i) = d;
        minind(i) = 0;
      }

      for (int j = 1; j < n; ++j) {
        if (i == j) continue;
        d = 0;
        for (int k = 0; k < m; ++k) {
          d += pow(X(i, k) - X(j, k), 2.0);

        }
        if (d < minsqdist(i)) {
          minsqdist(i) = d;
          minind(i) = j;
        }
      }
    }
    return minind + 1;
  }"
)

all.equal(findLowestDist(as.matrix(data_100)),
          unlist(low_dis_ids))
#[1] TRUE

findLowestDist(as.matrix(data_100000))
#works

可能会改进算法。