我正在开发一个函数,它将1xn向量x
作为输入并返回一个nxn矩阵L
。
我想通过矢量化循环来加快速度,但是有一个令我困惑的问题:循环索引b
取决于循环索引a
。任何帮助将不胜感激。
x = x(:);
n = length(x);
L = zeros(n, n);
for a = 1 : n,
for b = 1 : a-1,
c = b+1 : a-1;
if all(x(c)' < x(b) + (x(a) - x(b)) * ((b - c)/(b-a))),
L(a,b) = 1;
end
end
end
答案 0 :(得分:1)
通过快速测试,看起来你只对下三角形做了些什么。您可以使用与ind2sub
和arrayfun
类似的丑陋技巧进行矢量化
tril_lin_idx = find(tril(ones(n), -1));
[A, B] = ind2sub([n,n], tril_lin_idx);
C = arrayfun(@(a,b) b+1 : a-1, A, B, 'uniformoutput', false); %cell array
f = @(a,b,c) all(x(c{:})' < x(b) + (x(a) - x(b)) * ((b - c{:})/(b-a)));
L = zeros(n, n);
L(tril_lin_idx) = arrayfun(f, A, B, C);
我无法测试它,因为我没有x
而且我不知道预期的结果。我通常喜欢矢量化解决方案,但这可能会推动它有点过多:)。我会坚持你明确的for循环,这可能会更加清晰,Matlab的JIT应该能够轻松加速。您可以使用L(a,b) = all(...)
替换if。
<强> EDIT1 强>
更新版本,以防止在~ n^3
上浪费C
空间:
tril_lin_idx = find(tril(ones(n), -1));
[A, B] = ind2sub([n,n], tril_lin_idx);
c = @(a,b) b+1 : a-1;
f = @(a,b) all(x(c(a, b))' < x(b) + (x(a) - x(b)) * ((b - c(a, b))/(b-a)));
L = zeros(n, n);
L(tril_lin_idx) = arrayfun(f, A, B);
<强> EDIT2 强>
轻微变体,它不使用ind2sub,如果b
在a
上更复杂,则应该更容易修改。我为了速度而内联c
,似乎特别是调用函数句柄很昂贵。
[A,B] = ndgrid(1:n);
v = B<A; % which elements to evaluate
f = @(a,b) all(x(b+1:a-1)' < x(b) + (x(a) - x(b)) * ((b - (b+1:a-1))/(b-a)));
L = false(n);
L(v) = arrayfun(f, A(v), B(v));
答案 1 :(得分:1)
如果我正确理解您的问题,L(a, b) == 1
如果对于任何带有&lt; c&lt; b,(c,x(c))在连接(a,x(a))和(b,x(b))的线“下方”,右边?
这不是矢量化,但我找到了另一种方法。而不是将所有c与&lt; c&lt; b对于每个新b,我在(a,b)中保存了从a到c的最大斜率,并将其用于(a,b + 1)。 (我只试过一个方向,但我认为使用两个方向也是可能的。)
x = x(:);
n = length(x);
L = zeros(n);
for a = 1:(n - 1)
L(a, a + 1) = 1;
maxSlope = x(a + 1) - x(a);
for b = (a + 2):n
currSlope = (x(b) - x(a)) / (b - a);
if currSlope > maxSlope
maxSlope = currSlope;
L(a, b) = 1;
end
end
end
我不知道您的数据,但是对于一些随机数据,结果与原始代码相同(使用转置)。
答案 2 :(得分:1)
一个深奥的答案:你可以从1:n开始计算每个a,b,c,排除不关心,然后沿c维度进行全部计算。
[a, b, c] = ndgrid(1:n, 1:n, 1:n);
La = x(c)' < x(b) + (x(a) - x(b)) .* ((b - c)./(b-a));
La(b >= a | c <= b | c >= a) = true;
L = all(La, 3);
虽然jit可能会对for循环做得很好,因为它们做的很少。
编辑:仍然使用所有内存,但数学较少
[A, B, C] = ndgrid(1:n, 1:n, 1:n);
valid = B < A & C > B & C < A;
a = A(valid); b = B(valid); c = C(valid);
La = true(size(A));
La(valid) = x(c)' < x(b) + (x(a) - x(b)) .* ((b - c)./(b-a));
L = all(La, 3);
Edit2:替换最后一行,添加无元素c为真的
L = all(La,3) | ~any(valid,3);