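Improving the efficiency of randsample in a MATLAB Markov chain simulation

Time: 2014-06-06 09:58:55

Tags: matlab random random-sample markov-chains

I am using MATLAB to simulate an accumulation process in which several random walks accumulate toward a threshold in parallel. To choose which random walk is incremented at time t, I use randsample. If the vector V lists the active random walks and the vector P gives the probability with which each one should be selected, the call to randsample looks like this: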

randsample(V, 1, true, P);

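The problem is that the simulation is slow and randsample is the bottleneck: roughly 80% of the run time is spent in randsample calls.

Is there a relatively simple way to make randsample more efficient? Are there alternatives that would be faster?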

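2 answers:

Answer 0 (score: 2)

As I mentioned in the comments, the bottleneck comes from sampling one value at a time. Vectorizing the randsample call is much faster (assuming, of course, that the probability vector is constant).

Here is a quick benchmark: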

function testRandSample()
    v = 1:5;
    w = rand(numel(v),1); w = w ./ sum(w);
    n = 50000;

    % timeit
    t(1) = timeit(@() func1(v, w, n));
    t(2) = timeit(@() func2(v, w, n));
    t(3) = timeit(@() func3(v, w, n));
    disp(t)

    % check distribution of samples (should be close to w)
    tabulate(func1(v, w, n))
    tabulate(func2(v, w, n))
    tabulate(func3(v, w, n))
    disp(w*100)
end


function s = func1(v, w, n)
    s = randsample(v, n, true, w);
end

function s = func2(v, w, n)
    [~,idx] = histc(rand(n,1), [0;cumsum(w(:))./sum(w)]);
    s = v(idx);
end

function s = func3(v, w, n)
    cw = cumsum(w) / sum(w);
    s = zeros(n,1);
    for i=1:n
        s(i) = find(rand() <= cw, 1, 'first');
    end
    s = v(s);

    %s = v(arrayfun(@(~)find(rand() <= cw, 1, 'first'), 1:n));
end

Output (annotated):

% measured elapsed times for func1/2/3 respectively
  0.0016    0.0015    0.0790

% distribution of random sample from func1
  Value    Count   Percent
      1     4939      9.88%
      2    15049     30.10%
      3     7450     14.90%
      4    11824     23.65%
      5    10738     21.48%

% distribution of random sample from func2
  Value    Count   Percent
      1     4814      9.63%
      2    15263     30.53%
      3     7479     14.96%
      4    11743     23.49%
      5    10701     21.40%

% distribution of random sample from func3
  Value    Count   Percent
      1     4985      9.97%
      2    15132     30.26%
      3     7275     14.55%
      4    11905     23.81%
      5    10703     21.41%

% true population distribution
    9.7959
   30.4149
   14.7414
   23.4949
   21.5529

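As you can see, randsample is well optimized: it runs about as fast as the vectorized histc version (func2), and both are roughly 50 times faster than the one-sample-per-iteration loop in func3. The bottleneck you observe in your code is probably due to the lack of vectorization I described.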
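A side note on func2: newer MATLAB documentation marks histc as not recommended, and discretize (available from R2015a) returns bin indices directly. The following is a minimal sketch of the same idea under that assumption. The name func2b is made up for illustration, it was not part of the benchmark above, and no timing is claimed for it.

```matlab
function s = func2b(v, w, n)
    % Hypothetical variant of func2 (not timed in the benchmark above).
    % Assumes every weight is strictly positive, so the edges increase strictly.
    edges = [0; cumsum(w(:)) ./ sum(w)];  % cumulative probabilities as bin edges
    edges(end) = 1;                       % guard against round-off in the last edge
    idx = discretize(rand(n,1), edges);   % bin index of each uniform draw
    s = v(idx);                           % map bin indices back to the values
end
```

It is called the same way as the other variants, e.g. func2b(1:5, w, 50000).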

To see how much slower it gets, replace func1 with a looped version that draws one sample at a time:

function s = func1(v, w, n)
    s = zeros(n,1);
    for i=1:n
        s(i) = randsample(v, 1, true, w);
    end
end
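If the simulation needs one choice per time step but P stays fixed, one way to apply the vectorization advice is to draw all the choices up front and consume them inside the loop. This is a minimal sketch under that assumption. The names simulateWalks, nSteps, and counts are hypothetical and do not come from the question.

```matlab
function counts = simulateWalks(V, P, nSteps)
    % Hypothetical sketch: one vectorized randsample call replaces
    % nSteps single-sample calls. Assumes P does not change.
    idx = randsample(numel(V), nSteps, true, P);  % positions in V, drawn at once
    counts = zeros(size(V));
    for t = 1:nSteps
        k = idx(t);                   % walk chosen at time t
        counts(k) = counts(k) + 1;    % increment that walk by one unit
    end
end
```

If a walk reaches its threshold and leaves the active set, P changes, so the remaining pre-drawn choices would have to be discarded and redrawn from the updated P.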

Answer 1 (score: 1)

Perhaps this would be faster:

V(find(rand <= cumsum(P), 1)) % gives the same as randsample(V, 1, true, P)

I am assuming that P holds probabilities, i.e. that its entries sum to 1. Otherwise, normalize P:

V(find(rand <= cumsum(P)/sum(P), 1)) % gives the same as randsample(V, 1, true, P)

If P is always the same, precompute cumsum(P)/sum(P) to save time:

cP = cumsum(P)/sum(P); % precompute (just once)
V(find(rand <= cP, 1)) % gives the same as randsample(V, 1, true, P)
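In the question, V holds only the active walks, so P may change whenever a walk reaches its threshold. In that case the cumulative vector only needs refreshing at those moments, not on every draw. The sketch below assumes that the remaining probabilities are simply renormalized once a walk is dropped; the example values and the variable names k and chosen are made up for illustration.

```matlab
% Hypothetical sketch: refresh cP only when the active set changes.
V = [1 2 3 4];                 % labels of the active walks (example values)
P = [0.1 0.4 0.3 0.2];         % selection probabilities (example values)
cP = cumsum(P) / sum(P);
cP(end) = 1;                   % guard against round-off in the last entry

k = find(rand <= cP, 1);       % position of the chosen walk within V
chosen = V(k);                 % label of the chosen walk

% Suppose the chosen walk has now reached its threshold:
% drop it and renormalize the rest (assumption, see above).
V(k) = [];
P(k) = [];
cP = cumsum(P) / sum(P);       % refreshed once, reused for later draws
cP(end) = 1;
```

Between such updates, each draw costs one rand call and one find over the precomputed cP.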