在对以下load data
函数进行概要分析之后,我意识到以下几行是主要瓶颈:
dist_1 = dist[random_labels, :][:, random_labels]
dist_2 = dist[other_random_labels, :][:, other_random_labels]
其中dist
的大小为6000,6000
,随机标签的长度为5000
。
我正在尝试使用np.take
但
np.take(dist_1,[random_labels,random_labels]) == dist_1[random_labels, :][:, random_labels]
是False
。
np.take(dist_1,[random_labels,random_labels])
的尺寸为(2,5000)
有没有一种有效的方式来做到这一点?
编辑:这是我最近的:
dist_1 = np.take(np.take(dist, random_labels, axis=0), random_labels, axis=1)