我正在使用矩阵来计算图中节点之间的关系组合(细节无关紧要)。
我有一个N * N邻接矩阵,行和列对应于节点。因此,位置[5,7]是节点5与节点7的相乘次数。[7,5]也是如此。位置[3,3]是节点3出现的次数,因此它总共出现了多少次。
在每个循环中,我必须缩小矩阵。我取一个大小为n的向量,然后用该向量减去矩阵对角线。所以,我减少了每个节点的总数:就像我矩阵中的[1,1]和[2,2]和[3,3]等。
希望到目前为止,我说得通。这是这个问题。
这时,我已经修改了矩阵的对角线。现在,我要修改每个位置[i,j]
i != j and matrix[i,j] != 0
我想要对其进行修改,以便:
matrix[i,j] = min(matrix[i,i],matrix[j][j])
现在,我当然可以遍历每个索引对(i,j)并执行我在上面编写的内容。但这很慢。我希望有一些聪明的数学或小技巧可以使速度更快。
谢谢!
答案 0 :(得分:2)
首先,在进行任何优化之前,您应该进行概要分析:试图巧妙地处理在程序整个生命周期中仅花费数十毫秒的事情,或者仅占总运行时间的一小部分是没有意义的。
也就是说,您可以利用广播来矢量化最低限度的获取:
def slow(arr):
out = arr.copy()
for (i, j), x in np.ndenumerate(arr):
if i != j and arr[i,j] != 0:
out[i,j] = min(arr[i, i], arr[j, j])
return out
def fast(arr):
diag = arr.diagonal()
mins = np.minimum(diag, diag[:, None])
out = np.where(arr != 0, mins, arr)
out[np.diag_indices_from(arr)] = diag
return out
这给了我
In [61]: a = np.random.randint(0, 10, (100, 100))
In [62]: (slow(a) == fast(a)).all()
Out[62]: True
In [63]: %timeit slow(a)
11.9 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [64]: %timeit fast(a)
62.8 µs ± 916 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)