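I am using MATLAB to simulate an accumulation process in which several random walks accumulate in parallel toward a threshold. To choose which random walk is incremented at time t, I use randsample. If the vector V holds the active random walks and the vector P holds the probability with which each of them should be chosen, the call to randsample looks like this: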
randsample(V, 1, true, P);
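The problem is that the simulation is slow and randsample is the bottleneck: roughly 80% of the run time is spent in randsample calls.
Is there a relatively simple way to make randsample more efficient? Are there alternatives that would be faster?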
Answer 0 (score: 2)
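As I mentioned in the comments, the bottleneck comes from sampling one value at a time. Vectorizing the randsample call is much faster (assuming, of course, that the probability vector is constant).
Here is a quick benchmark: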
function testRandSample()
v = 1:5;
w = rand(numel(v),1); w = w ./ sum(w);
n = 50000;
% timeit
t(1) = timeit(@() func1(v, w, n));
t(2) = timeit(@() func2(v, w, n));
t(3) = timeit(@() func3(v, w, n));
disp(t)
% check distribution of samples (should be close to w)
tabulate(func1(v, w, n))
tabulate(func2(v, w, n))
tabulate(func3(v, w, n))
disp(w*100)
end
function s = func1(v, w, n)
s = randsample(v, n, true, w);
end
function s = func2(v, w, n)
[~,idx] = histc(rand(n,1), [0;cumsum(w(:))./sum(w)]);
s = v(idx);
end
function s = func3(v, w, n)
cw = cumsum(w) / sum(w);
s = zeros(n,1);
for i=1:n
s(i) = find(rand() <= cw, 1, 'first');
end
s = v(s);
%s = v(arrayfun(@(~)find(rand() <= cw, 1, 'first'), 1:n));
end
Output (annotated):
% measured elapsed times for func1/2/3 respectively
0.0016 0.0015 0.0790
% distribution of random sample from func1
Value Count Percent
1 4939 9.88%
2 15049 30.10%
3 7450 14.90%
4 11824 23.65%
5 10738 21.48%
% distribution of random sample from func2
Value Count Percent
1 4814 9.63%
2 15263 30.53%
3 7479 14.96%
4 11743 23.49%
5 10701 21.40%
% distribution of random sample from func3
Value Count Percent
1 4985 9.97%
2 15132 30.26%
3 7275 14.55%
4 11905 23.81%
5 10703 21.41%
% true population distribution
9.7959
30.4149
14.7414
23.4949
21.5529
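As you can see, randsample is well optimized: the vectorized call (func1) takes about as long as the histc-based version (func2), and both are roughly 50 times faster than func3, which draws one sample per loop iteration. The bottleneck you observe in your code is probably due to the lack of vectorization I described.
To see how slow that gets, replace func1 with a looped version that samples one value at a time: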
function s = func1(v, w, n)
s = zeros(n,1);
for i=1:n
s(i) = randsample(v, 1, true, w);
end
end
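Applied to the simulation in the question, the idea could look roughly like the sketch below: draw the choice for every time step in one vectorized call, then consume one pre-drawn value per step. This is only an illustration under assumptions. The values of V, P and nSteps are made up, the increment rule is a placeholder, and it presumes that P stays constant and that an upper bound on the number of steps is known.

```matlab
% Hypothetical setup; V, P and nSteps are assumed values, not from the question.
V = 1:5;                                  % active random walks
P = [0.10 0.30 0.15 0.25 0.20];           % selection probabilities (constant)
nSteps = 50000;                           % assumed upper bound on the number of steps

% One vectorized call instead of nSteps scalar calls
% (randsample requires the Statistics and Machine Learning Toolbox).
picks = randsample(V, nSteps, true, P);

x = zeros(size(V));                       % accumulated value of each walk
for t = 1:nSteps
    k = picks(t);                         % walk chosen at step t (V is 1:5 here, so k is also an index)
    x(k) = x(k) + 1;                      % placeholder for the real increment rule
end
```

If the number of steps is not known in advance (for example because the run stops once a threshold is reached), the same idea should still work by drawing the picks in fixed-size blocks and refilling when a block is used up.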
Answer 1 (score: 1)
Maybe this would be faster:
idx = find(rand <= cumsum(P), 1); %// V(idx) is distributed like randsample(V, 1, true, P)
I'm assuming that P contains probabilities, i.e. that its entries sum to 1. Otherwise, normalize P:
idx = find(rand <= cumsum(P)/sum(P), 1); %// V(idx) is distributed like randsample(V, 1, true, P)
If P is always the same, precompute cumsum(P)/sum(P) to save time:
cp = cumsum(P)/sum(P); %// precompute (just once)
idx = find(rand <= cp, 1); %// V(idx) is distributed like randsample(V, 1, true, P)
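For illustration, here is a minimal sketch of how the precomputed cumulative vector might be used inside a simulation loop. The labels in V, the probabilities in P and the step count are assumed values, and the counter is only a stand-in for whatever the real simulation does with the chosen walk. Forcing the last entry of cp to 1 is an extra safeguard I added in case floating-point round-off leaves it slightly below 1, in which case find could return an empty result.

```matlab
% Hypothetical setup; V, P and nSteps are assumed values.
V = [11 12 13 14];                        % labels of the active random walks
P = [0.1 0.4 0.2 0.3];                    % selection probabilities

cp = cumsum(P) / sum(P);                  % cumulative distribution, computed once
cp(end) = 1;                              % guard against round-off in the last entry

nSteps = 10000;
counts = zeros(size(V));                  % how often each walk was chosen
for t = 1:nSteps
    idx = find(rand <= cp, 1, 'first');   % index of the walk chosen at step t
    counts(idx) = counts(idx) + 1;        % placeholder for "increment walk V(idx)"
end

disp(counts / nSteps)                     % empirical frequencies, should be close to P
```

If the set of active walks changes during the run, cp would have to be recomputed whenever V or P changes, but not on every step.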