快速计算一个向量的余弦相似度

时间:2017-03-30 04:55:18

标签: c++ r performance vector rcpp

我非常希望听到优化代码的想法,以计算向量x(长度为l)与n个其他向量(存储在任何结构中)的余弦相似度例如带有m行和n列的矩阵l

n的值通常远大于l的值。

我目前正在使用此自定义Rcpp函数来计算向量x与矩阵m的每一行的相似度:

library(Rcpp)
cppFunction('NumericVector cosine_x_to_m(NumericVector x, NumericMatrix m) {
  int nrows = m.nrow();
  NumericVector out(nrows);
  for (int i = 0; i < nrows; i++) {
    NumericVector y = m(i, _);
    out[i] = sum(x * y) / sqrt(sum(pow(x, 2.0)) * sum(pow(y, 2.0)));
  }
  return out;
}')

改变nl,我得到以下各种时间:

enter image description here

以下可重现的代码。

# Function to simulate data
sim_data <- function(l, n) {
  # Feature vector to be used for computing similarity
  x <- runif(l)

  # Matrix of feature vectors (1 per row) to compare against x
  m <- matrix(runif(n * l), nrow = n)

  list(x = x, m = m)
}

# Rcpp function to compute similarity of x to each row of m
library(Rcpp)
cppFunction('NumericVector cosine_x_to_m(NumericVector x, NumericMatrix m) {
  int nrows = m.nrow();
  NumericVector out(nrows);
  for (int i = 0; i < nrows; i++) {
    NumericVector y = m(i, _);
    out[i] = sum(x * y) / sqrt(sum(pow(x, 2.0)) * sum(pow(y, 2.0)));
  }
  return out;
}')    

# Timer function
library(microbenchmark)
timer <- function(l, n) {
  dat <- sim_data(l, n)
  microbenchmark(cosine_x_to_m(dat$x, dat$m))
}

# Results for grid of l and n
library(tidyverse)
results <- cross_d(list(l = seq(200, 1000, by = 200), n = seq(500, 4000, by = 500))) %>% 
  mutate(timings = map2(l, n, timer))

# Plot results
results_plot <- results %>%
  unnest(timings) %>% 
  mutate(time = time / 1000000) %>%  # Convert time to seconds
  group_by(l, n) %>% 
  summarise(mean = mean(time), ci = 1.96 * sd(time) / sqrt(n()))

pd <- position_dodge(width = 20)

results_plot %>% 
  ggplot(aes(n, mean, group= l)) +
  geom_line(aes(color = factor(l)), position = pd, size = 2) +
  geom_errorbar(aes(ymin = mean - ci, ymax = mean + ci), position = pd, width = 100) +
  geom_point(position = pd, size = 2) +
  scale_color_brewer(palette = "Blues") +
  theme_minimal() +
  labs(x = "n", y = "Seconds", color = "l") +
  ggtitle("Algorithm Runtime",
          subtitle = "Error bars represent 95% confidence intervals")

1 个答案:

答案 0 :(得分:1)

我正在使用Microsoft R(使用英特尔MKL),这使得矩阵乘法更快,但为了公平比较,我将其设置为单线程。

prepend_before_action :require_no_authentication, only: [ :edit, :update]

在我的测试中,这个纯R版本setMKLthreads(1) 比你的快两倍。

cosine_x_to_m

在C / C ++中重写cosine_x_to_m2 = function(x,m){ x = x / sqrt(crossprod(x)); return( as.vector((m %*% x) / sqrt(rowSums(m^2))) ); } 使其更快,大约比原版快四倍。

rowSums(m^2)

初步表现:

answer

最终版本表现:

Initial performance