我得到了这段代码,其中距离是一个下三角矩阵,定义如下:
Map
我的问题是np.where用大矩阵(例如2000 * 100)很慢 如何通过改进np.where或改变算法来加速这段代码?
编辑:正如MaxU所指出的,这里最好的优化不是生成平方矩阵并使用迭代器。
distance = np.tril(scipy.spatial.distance.cdist(points, points))
def make_them_touch(distance):
"""
Return the every distance where two points touched each other. See example below.
"""
thresholds = np.unique(distance)[1:] # to avoid 0 at the beginning, not taking a lot of time at all
result = dict()
for t in thresholds:
x, y = np.where(distance == t)
result[t] = [i for i in zip(x,y)]
return result
答案 0 :(得分:1)
UPDATE1:这里是一个上三角形距离矩阵的片段(由于距离矩阵始终是对称的,它不应该真正重要):
from itertools import combinations
res = {tup[0]:tup[1] for tup in zip(pdist(points), list(combinations(range(len(points)), 2)))}
结果:
In [111]: res
Out[111]:
{1.4142135623730951: (0, 1),
4.69041575982343: (0, 2),
4.898979485566356: (1, 2)}
UPDATE2:此版本将支持距离重复:
In [164]: import pandas as pd
首先我们构建一个Pandas.Series:
In [165]: s = pd.Series(list(combinations(range(len(points)), 2)), index=pdist(points))
In [166]: s
Out[166]:
2.0 (0, 1)
6.0 (0, 2)
12.0 (0, 3)
4.0 (1, 2)
10.0 (1, 3)
6.0 (2, 3)
dtype: object
现在我们可以按索引进行分组并生成坐标列表:
In [167]: s.groupby(s.index).apply(list)
Out[167]:
2.0 [(0, 1)]
4.0 [(1, 2)]
6.0 [(0, 2), (2, 3)]
10.0 [(1, 3)]
12.0 [(0, 3)]
dtype: object
PS这里的主要想法是,如果你打算在之后将它弄平并且去除重复,你就不应该建立方形距离矩阵。