我有以下代码输出array1
的值小于或等于array2
的每个元素的值。这两个数组的长度不同。由于数组很大(~500,000
元素),因此for循环非常慢。仅供参考,两个阵列总是按升序排列。
任何帮助使这个矢量操作并加快它的速度将不胜感激。
我正在考虑interp1()
采用'最近'选项的某种多步骤流程。然后找到相应的outArray
大于array2
的位置,然后以某种方式固定点......但我认为必须有更好的方法。
array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
outArray = nan(size(array2));
for a =1:numel(array2)
outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end
返回:
outArray =
5 5 15 24
答案 0 :(得分:3)
这是一种可能的矢量化:
[~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
outArray = array1(idx);
在最近的版本中,由于JIT编译,MATLAB已经非常擅长执行良好的旧非矢量化循环。
下面是一些类似于你的代码,它利用了两个数组的排序事实(因此如果pos(a) = find(array1<=array2(a), 1, 'last')
,那么我们可以保证在下一次迭代中计算的pos(a+1)
不小于上一个pos(a)
)
pos = 1;
idx = zeros(size(array2));
for a=1:numel(array2)
while pos <= numel(array1) && array1(pos) <= array2(a)
pos = pos + 1;
end
idx(a) = pos-1;
end
%idx(idx==0) = []; %# in case min(array2) < min(array1)
outArray = array1(idx);
注意:当array2
的最小值小于array1
的最小值时(即find(array1<=array2(a))
为空时),注释行处理大小写
我对目前发布的所有方法进行了比较,这确实是最快的方法。对于长度为N = 5000的向量,定时(使用TIMEIT函数执行)是:
0.097398 # your code
0.39127 # my first vectorized code
0.00043361 # my new code above
0.0016276 # Mohsen Nosratinia's code
以下是N = 500000的时间安排:
(? too-long) # your code
(out-of-mem) # my first vectorized code
0.051197 # my new code above
0.25206 # Mohsen Nosratinia's code
..从你报告的最初10分钟到0.05秒,这是一个相当不错的改进!
如果您想要重现结果,请输入以下测试代码:
function [t,v] = test_array_find()
%array2 = [5 6 18 25];
%array1 = [1 5 9 15 22 24 31];
N = 5000;
array1 = sort(randi([100 1e6], [1 N]));
array2 = sort(randi([min(array1) 1e6], [1 N]));
f = {...
@() func1(array1,array2); %# Aero Engy
@() func2(array1,array2); %# Amro
@() func3(array1,array2); %# Amro
@() func4(array1,array2); %# Mohsen Nosratinia
};
t = cellfun(@timeit, f);
v = cellfun(@feval, f, 'UniformOutput',false);
assert( isequal(v{:}) )
end
function outArray = func1(array1,array2)
%idx = arrayfun(@(a) find(array1<=a, 1, 'last'), array2);
idx = zeros(size(array2));
for a=1:numel(array2)
idx(a) = find(array1 <= array2(a), 1, 'last');
end
outArray = array1(idx);
end
function outArray = func2(array1,array2)
[~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
outArray = array1(idx);
end
function outArray = func3(array1,array2)
pos = 1;
lastPos = numel(array1);
idx = zeros(size(array2));
for a=1:numel(array2)
while pos <= lastPos && array1(pos) <= array2(a)
pos = pos + 1;
end
idx(a) = pos-1;
end
%idx(idx==0) = []; %# in case min(array2) < min(array1)
outArray = array1(idx);
end
function outArray = func4(array1,array2)
[~,I] = sort([array1 array2]);
a1size = numel(array1);
J = find(I>a1size);
outArray = nan(size(array2));
for k=1:numel(J),
if I(J(k)-1)<=a1size,
outArray(k) = array1(I(J(k)-1));
else
outArray(k) = outArray(k-1);
end
end
end
答案 1 :(得分:2)
速度缓慢的一个原因是,您要将array1
中的所有元素与array2
中的所有元素进行比较,以便它们分别包含M
和N
个元素,复杂性为O(M*N)
。但是,由于数组已经排序,因此有一个线性时间O(M+N)
解决方案
array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
outArray = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);
ks = 1;
while ks <= n2 && array2(ks) < array1(1)
ks = ks + 1;
end
for k2=ks:n2
while k1 < n1 && array2(k2) >= array1(k1+1)
k1 = k1+1;
end
outArray(k2) = array1(k1);
end
这是一个测试用例,用于测量每个方法运行两个长度为500,000的数组所需的时间。
array2 = 1:500000;
array1 = array2-1;
tic
outArray1 = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);
ks = 1;
while ks <= n2 && array2(ks) < array1(1)
ks = ks + 1;
end
for k2=ks:n2
while k1 < n1 && array2(k2) >= array1(k1+1)
k1 = k1+1;
end
outArray1(k2) = array1(k1);
end
toc
tic
outArray2 = nan(size(array2));
for a =1:numel(array2)
outArray2(a) = array1(find(array1 <= array2(a),1,'last'));
end
toc
结果是
Elapsed time is 0.067637 seconds.
Elapsed time is 418.458722 seconds.
答案 2 :(得分:0)
注意: 这是我最初的解决方案,也是Amro答案的基准测试。但是,它比我在其他答案中提供的线性时间解决方案要慢。
速度缓慢的一个原因是您要将array1
中的所有元素与array2
中的所有元素进行比较,因此如果它们包含M
和N
元素,则复杂性为O(M*N)
。但是,您可以将它们连接起来并将它们排序在一起,并获得更快的复杂算法(M+N)*log2(M+N)
。这是一种方法:
array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
[~,I] = sort([array1 array2]);
a1size = numel(array1);
J = find(I>a1size);
outArray = nan(size(array2));
for k=1:numel(J),
if I(J(k)-1)<=a1size,
outArray(k) = array1(I(J(k)-1));
else
outArray(k) = outArray(k-1);
end
end
disp(outArray)
% Test using original code
outArray = nan(size(array2));
for a =1:numel(array2)
outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end
disp(outArray)
连接数组将是
>> [array1 array2]
ans =
1 5 9 15 22 24 31 5 6 18 25
和
>> [B,I] = sort([array1 array2])
B =
1 5 5 6 9 15 18 22 24 25 31
I =
1 2 8 9 3 4 10 5 6 11 7
它显示在排序数组B
中,第一个5
来自连接数组中的第二个位置,第二个位于8个位置,依此类推。因此,要查找array1
中小于array2
中给定元素的最大元素,我们只需要遍历I
中大于array1
大小的所有索引(因此属于array2
)并返回并找到属于array1
的最近索引。 J
包含这些元素在向量I
中的位置:
>> J = find(I>a1size)
J =
3 4 7 10
现在,for循环遍历这些索引并检查I
之前的索引是否恰好在J
引用的每个索引属于array1
之前。如果它属于array1
,则会从array1
检索它,否则会复制为先前索引找到的值。
请注意,如果array2
包含的元素小于array1
中的最小元素,则代码和此代码都会失败。