我有一个带有嵌套for循环的python函数,它被调用了数千次,而且太慢了。从我在网上看到的,应该有一种方法来通过numpy矢量化来优化它,以便迭代在更快的C代码而不是python中完成。但是,我以前从未与numpy合作过,我无法弄明白。
功能如下。第一个参数是二维数组(列表列表)。第二个参数是要检查的2D数组的行列表。第三个参数是要检查的2D数组的列列表(请注意,行数不等于cols数)。第四个参数是用于比较2D阵列的元素的值。我试图返回一个列表,列表中每列包含一个列表,其中所有行索引都对应于等于val的元素。
def filter_indices(my_2d_arr, rows, cols, val):
result_indices = []
for c in cols:
col_indices = []
for idx in rows:
if my_2d_arr[idx][c] == val:
col_indices.append(idx)
result_indices.append(col_indices)
return result_indices
就像我说的那样,这太慢了,我很困惑我怎么能把它矢量化这个很难看。任何指针/指导都会很棒。
@ B.M。感谢您的回答。我将自己的解决方案与我的其余代码分开运行,并将其与之前的函数进行比较,而不是numpy。就像你说的那样,我的原始功能确实比numpy工作得快得多。但是,当它作为我的代码的一部分运行时,我的解决方案实际上由于某种原因而变慢。我确实必须为你的函数添加一些内容并修改我现有的一些代码以使它们兼容,但是我被抛弃了,因为cProfile显示我原来的filter_indices函数更快,numpy版本更快而不是新的numpy。我不知道numpy filter_indices如何花费这么长时间,考虑到与我的其余代码分开运行它会更快。
这是我原来的filter_indices没有numpy:
def filter_indices_orig(a, data_indices, feature_set, val):
result_indices = []
for feature_no in feature_set:
feature_indices = []
for idx in data_indices:
if a[idx][feature_no] == val:
feature_indices.append(idx)
result_indices.append(feature_indices)
return result_indices
这是我稍微修改过的带有numpy的filter_indices:
def filter_indices(a, data_indices, feature_set, val):
result_indices = {}
sub = a[np.meshgrid(data_indices, feature_set, indexing='ij')]
r, c = (sub == val).nonzero()
rs = np.take(data_indices, r)
cs = np.take(feature_set, c)
coords = zip(rs, cs)
for r, c in coords:
feat_indices = result_indices.get(c, [])
feat_indices.append(r)
result_indices[c] = feat_indices
return result_indices
当我只搜索几列时,我发现numpy解决方案速度较慢,但是当我查找大量列时速度更快。不幸的是,即使专门使用我原来的非numpy解决方案,当搜索大量列时使用numpy解决方案搜索几列仍然比我原来的解决方案慢,我不明白。
答案 0 :(得分:0)
这是一个函数,它返回所选子数组中值为val
的像素的2个数组,行索引和cols索引:
def filter_indices_numpy(a,rows,cols,val):
sub=a[meshgrid(rows,cols,indexing='ij')]
r,c = (sub==val).nonzero()
return take(rows,r),take(cols,c)
示例:
a=randint(0,3,(5,5))
#array([[0, 1, 0, 2, 2],
# [0, 0, 2, 0, 0],
# [2, 1, 1, 0, 0],
# [1, 0, 0, 1, 2],
# [2, 1, 0, 0, 0]])
filter_indices_numpy(a,[1,2,3],[1,2,3],0)
#(array([1, 1, 2, 3, 3]), array([1, 3, 3, 1, 2]))
一些解释:
meshgrid(rows,cols,indexing='ij')
是所选行和列的索引。
sub
是子数组。 r,c = (sub==val).nonzero()
是子数组中值为val
的索引。 take(rows,r),take(cols,c)
翻译数组a
中的索引。
测试:a=randint(0,200,(1000,1000));rows=cols=arange(100)
In [4]: %timeit filter_indices(a,rows,cols,0)
10 loops, best of 3: 23.1 ms per loop
In [5]: %timeit filter_indices_numpy(a,rows,cols,0)
1000 loops, best of 3: 933 µs per loop
它快了约25倍。