如何向依赖的For-Loops进行矢量化

时间:2013-08-27 18:04:37

标签: matlab for-loop vectorization

我正在开发一个函数,它将1xn向量x作为输入并返回一个nxn矩阵L
我想通过矢量化循环来加快速度,但是有一个令我困惑的问题:循环索引b取决于循环索引a。任何帮助将不胜感激。

x = x(:); 
n = length(x);
L = zeros(n, n);
for a = 1 : n,
    for b = 1 : a-1,
        c = b+1 : a-1;
        if all(x(c)' < x(b) + (x(a) - x(b)) * ((b - c)/(b-a))),
            L(a,b) = 1;
        end
    end
end

3 个答案:

答案 0 :(得分:1)

通过快速测试,看起来你只对下三角形做了些什么。您可以使用与ind2subarrayfun类似的丑陋技巧进行矢量化

tril_lin_idx = find(tril(ones(n), -1));
[A, B] = ind2sub([n,n], tril_lin_idx);
C = arrayfun(@(a,b) b+1 : a-1, A, B, 'uniformoutput', false); %cell array
f = @(a,b,c) all(x(c{:})' < x(b) + (x(a) - x(b)) * ((b - c{:})/(b-a)));
L = zeros(n, n);
L(tril_lin_idx) = arrayfun(f, A, B, C);

我无法测试它,因为我没有x而且我不知道预期的结果。我通常喜欢矢量化解决方案,但这可能会推动它有点过多:)。我会坚持你明确的for循环,这可能会更加清晰,Matlab的JIT应该能够轻松加速。您可以使用L(a,b) = all(...)替换if。

<强> EDIT1

更新版本,以防止在~ n^3上浪费C空间:

tril_lin_idx = find(tril(ones(n), -1));
[A, B] = ind2sub([n,n], tril_lin_idx);
c = @(a,b) b+1 : a-1;
f = @(a,b) all(x(c(a, b))' < x(b) + (x(a) - x(b)) * ((b - c(a, b))/(b-a)));
L = zeros(n, n);
L(tril_lin_idx) = arrayfun(f, A, B);

<强> EDIT2

轻微变体,它不使用ind2sub,如果ba上更复杂,则应该更容易修改。我为了速度而内联c,似乎特别是调用函数句柄很昂贵。

[A,B] = ndgrid(1:n);
v = B<A; % which elements to evaluate
f = @(a,b) all(x(b+1:a-1)' < x(b) + (x(a) - x(b)) * ((b - (b+1:a-1))/(b-a)));
L = false(n);
L(v) = arrayfun(f, A(v), B(v));

答案 1 :(得分:1)

如果我正确理解您的问题,L(a, b) == 1如果对于任何带有&lt; c&lt; b,(c,x(c))在连接(a,x(a))和(b,x(b))的线“下方”,右边?

这不是矢量化,但我找到了另一种方法。而不是将所有c与&lt; c&lt; b对于每个新b,我在(a,b)中保存了从a到c的最大斜率,并将其用于(a,b + 1)。 (我只试过一个方向,但我认为使用两个方向也是可能的。)

x = x(:);
n = length(x);
L = zeros(n);

for a = 1:(n - 1)
  L(a, a + 1) = 1;
  maxSlope = x(a + 1) - x(a);
  for b = (a + 2):n
    currSlope = (x(b) - x(a)) / (b - a);
    if currSlope > maxSlope
      maxSlope = currSlope;
      L(a, b) = 1;
    end
  end
end

我不知道您的数据,但是对于一些随机数据,结果与原始代码相同(使用转置)。

答案 2 :(得分:1)

一个深奥的答案:你可以从1:n开始计算每个a,b,c,排除不关心,然后沿c维度进行全部计算。

[a, b, c] = ndgrid(1:n, 1:n, 1:n);

La = x(c)' < x(b) + (x(a) - x(b)) .* ((b - c)./(b-a));
La(b >= a | c <= b | c >= a) = true;

L = all(La, 3);

虽然jit可能会对for循环做得很好,因为它们做的很少。

编辑:仍然使用所有内存,但数学较少

[A, B, C] = ndgrid(1:n, 1:n, 1:n);

valid = B < A & C > B & C < A;
a = A(valid); b = B(valid); c = C(valid);

La = true(size(A));
La(valid) = x(c)' < x(b) + (x(a) - x(b)) .* ((b - c)./(b-a));
L = all(La, 3);

Edit2:替换最后一行,添加无元素c为真的

L = all(La,3) | ~any(valid,3);