检查大型numpy矩阵中的三角形不等式

时间:2019-01-05 20:00:49

标签: python algorithm numpy distance

我有一个非负浮点数的对称NumPy矩阵D。第i行和第j列中的数字表示对象ij之间的距离,无论它们是什么。矩阵很大(〜10,000行/列)。我想检查矩阵中的所有距离是否都服从三角形不等式,即:对于所有D[i,j]<=D[i,k]+D[k,j]ij的{​​{1}}。

使用三重嵌套循环可以非常有效地解决问题。但是,有没有更快的矢量化解决方案?

1 个答案:

答案 0 :(得分:2)

您当然可以使用(未​​试用)足够容易地向量化最内部的循环:

for i in range(N):
    for j in range(i):
        assert all(D[i,j] <= D[i,:] + D[:,j])

对于双重矢量化,您可以遍历k(也未经测试):

for k in range(N):
    row = D[k,:].reshape(1, N)
    col = D[:,k].reshape(N, 1)
    assert all(D <= row + col)

({row + col生成与D大小相同的方阵