我有以下代码遍历名为“ m”的2d numpy数组。它工作非常慢。如何使用numpy函数转换此代码,从而避免使用for循环?
pairs = []
for i in range(size):
for j in range(size):
if(i >= j):
continue
if(m[i][j] + m[j][i] >= 0.75):
pairs.append([i, j, m[i][j] + m[j][i]])
答案 0 :(得分:4)
一种优化代码的方法是避免比较if (i >= j)
。要仅遍历数组的下三角而不进行比较,必须使内部循环以最外部循环的i
的值开始。这样,您就可以避免进行size x size
if
比较。
import numpy as np
size = 5000
m = np.random.rand(size, size)
pairs = []
for i in range(size):
for j in range(i , size):
if(m[i][j] + m[j][i] >= 0.75):
pairs.append([i, j, m[i][j] + m[j][i]])
答案 1 :(得分:4)
您可以使用NumPy使用向量化方法。这个想法是:
m
,然后创建与m+m.T
等效的m[i][j] + m[j][i]
,其中m.T
是矩阵转置并命名为summ
np.triu
(summ)
返回矩阵的上三角部分(这等效于在代码中使用continue
来忽略下三角部分)。这样可以避免在代码中使用显式if(i >= j):
。在这里,您必须使用k=1
来排除对角线元素。默认情况下,k=0
也包括对角线元素。 np.argwhere
获得点的索引,其中总和m+m.T
等于0.75 可验证的示例(使用小的3x3随机数据集)
import numpy as np
np.random.seed(0)
m = np.random.rand(3,3)
summ = m + m.T
index = np.argwhere(np.triu(summ, k=1)>=0.75)
pairs = [(x,y, summ[x,y]) for x,y in index]
print (pairs)
# # [(0, 1, 1.2600725493693163), (0, 2, 1.0403505873343364), (1, 2, 1.537667113848736)]
进一步的性能改进
我刚刚想出了一种更快的方法来生成最终的pairs
列表,从而避免显式的for循环
pairs = list(zip(index[:, 0], index[:, 1], summ[index[:,0], index[:,1]]))