Matlab - 加速嵌套For循环

时间:2010-01-06 22:05:34

标签: matlab

一个简单的问题,但我对MATLAB不太满意。我有向量x,(n x 1)y,(m x 1)和w = [x;y]。我想将M(n + m x 1)定义为M(i)= x的元素数小于或等于w(i)(w被排序)。这只是没有削减它:

N = n + m;
M = zeros(N,1);
for i = 1:N
  for j = 1:n
    if x(j) <= w(i)
      M(i) = M(i) + 1;
    end
  end
end

这不是一种特别聪明的方法,我的一些数据向量m和n大约是100000。

谢谢!

5 个答案:

答案 0 :(得分:9)

这可能看起来很神秘,但它应该给你与嵌套循环相同的结果:

M = histc(-x(:)',[-fliplr(w(:)') inf]);
M = cumsum(fliplr(M(1:N)))';

以上假设w已按升序排序。

说明

您的排序向量w可以被视为bin边缘,用于创建具有HISTC函数的直方图。计算每个bin中(即边缘之间)的值的数量后,使用CUMSUM函数对这些bin的累积总和将为您提供向量M。上面的代码看起来如此混乱(带有否定和函数FLIPLR)的原因是因为你想在x 中找到小于或等于中每个值的值w,但函数HISTC以下列方式存储数据:

  如果n(k)

x(i)会计算值edges(k) <= x(i) < edges(k+1)

请注意,小于用于每个bin的上限。您可能希望翻转行为,以便根据规则edges(k) < x(i) <= edges(k+1)进行分箱,这可以通过否定要分箱的值来实现,否定边缘,翻转边缘(因为边缘输入到HISTC必须单调非减少),然后翻转返回的bin计数。值inf用作边值,用于计算小于第一个bin中w中最低值的所有内容。

如果您想在x中找到小于 w中每个值的值,则代码会更简单:

M = histc(x(:)',[-inf w(:)']);
M = cumsum(M(1:N))';

答案 1 :(得分:5)

至少内圈可以替换为:

M(i)=sum(x<=w(i))

这将提供显着的性能改进。您可以考虑使用arrayfun:

M = arrayfun(@(wi)( sum( x <= wi ) ), w);

arrayfun不太可能在外部for循环上提供大量增益,但可能值得一试。

编辑:我应该注意,无需对wx进行排序,以使此操作正常运行。

编辑:fwiw,我决定做一些实际的性能测试,所以我运行了这个程序:

n = 100000;     m = n;

N = n + m;

x = rand(n, 1);
w = [x; rand(m, 1)];

tic;
M = zeros(N,1);
for i = 1:N
  for j = 1:n
    if x(j) <= w(i)
      M(i) = M(i) + 1;
    end
  end
end
perf = toc;
fprintf(1, 'Original : %4.3f sec\n', perf);

w = sort(w); % presorted, so don't incur time cost;
tic;
M = histc(-x(:)',[-fliplr(w(:)') inf]);
M = cumsum(fliplr(M(1:N)))';
perf = toc;
fprintf(1, 'gnovice : %4.3f sec\n', perf);

tic;
M = zeros(N,1);
for i = 1:N
    M(i)=sum(x<=w(i));
end
perf = toc;
fprintf(1, 'mine/loop : %4.3f sec\n', perf);

tic;
M = arrayfun(@(wi)( sum( x <= wi ) ), w);
perf = toc;
fprintf(1, 'mine/arrayfun : %4.3f sec\n', perf);

并得到n = 1000的这些结果:

Original : 0.086 sec
gnovice : 0.002 sec
mine/loop : 0.042 sec
mine/arrayfun : 0.070 sec

并且对于n = 100000:

Original : too long to tell ( >> 1m )
gnovice : 0.050 sec
mine/loop : too long to tell ( >> 1m )
mine/arrayfun : too long to tell ( >> 1m )

答案 2 :(得分:1)

暂时没有完成MATLAB,但这应该起作用:

  • 使用内置排序算法向上排序x。

  • 使用具有漫游索引的循环仅在x(j)

    上迭代一次
    j = 1;
    for i = 1:N
      while j <= n && x(j) <= w(i)
        M(i) = M(i) + 1;
        j = j+1;
      end
    end
    
  • 最后累积总和

    for j =2:n
      M(j) = M(j-1) + M(j)
    end
    

答案 3 :(得分:1)

试试这个:

M = sum( bsxfun(@le, w', sort(w)) , 2 )

答案 4 :(得分:0)

我没有在我面前使用Matlab所以我无法确认这是100%的功能,但你可能想尝试类似的东西:

for i = 1:N
    M(i) = arrayfun(@(ary,val)length(find(ary <= val)), x, w(i))
end