如何使三角矩阵更有效地使用np.where?

时间:2018-06-18 10:07:08

标签: python numpy scipy itertools

我得到了这段代码,其中距离是一个下三角矩阵,定义如下:

Map

我的问题是np.where用大矩阵(例如2000 * 100)很慢 如何通过改进np.where或改变算法来加速这段代码?

编辑:正如MaxU所指出的,这里最好的优化不是生成平方矩阵并使用迭代器。

示例:

distance = np.tril(scipy.spatial.distance.cdist(points, points))  
def make_them_touch(distance):
    """
    Return the every distance where two points touched each other. See example below.
    """
    thresholds = np.unique(distance)[1:] # to avoid 0 at the beginning, not taking a lot of time at all
    result = dict()
    for t in thresholds:
            x, y = np.where(distance == t)
            result[t] = [i for i in zip(x,y)]
    return result

1 个答案:

答案 0 :(得分:1)

UPDATE1:这里是一个三角形距离矩阵的片段(由于距离矩阵始终是对称的,它不应该真正重要):

from itertools import combinations

res = {tup[0]:tup[1] for tup in zip(pdist(points), list(combinations(range(len(points)), 2)))}

结果:

In [111]: res
Out[111]:
{1.4142135623730951: (0, 1),
 4.69041575982343: (0, 2),
 4.898979485566356: (1, 2)}

UPDATE2:此版本将支持距离重复:

In [164]: import pandas as pd

首先我们构建一个Pandas.Series:

In [165]: s = pd.Series(list(combinations(range(len(points)), 2)), index=pdist(points))

In [166]: s
Out[166]:
2.0     (0, 1)
6.0     (0, 2)
12.0    (0, 3)
4.0     (1, 2)
10.0    (1, 3)
6.0     (2, 3)
dtype: object

现在我们可以按索引进行分组并生成坐标列表:

In [167]: s.groupby(s.index).apply(list)
Out[167]:
2.0             [(0, 1)]
4.0             [(1, 2)]
6.0     [(0, 2), (2, 3)]
10.0            [(1, 3)]
12.0            [(0, 3)]
dtype: object

PS这里的主要想法是,如果你打算在之后将它弄平并且去除重复,你就不应该建立方形距离矩阵。