我的数据对象是以下的实例:
class data_instance:
def __init__(self, data, tlabel):
self.data = data # 1xd numpy array
self.true_label = tlabel # integer {1,-1}
到目前为止,在代码中,我有一个名为data_history
的列表,其中包含data_istance
和一组centers
(带有形状(k,d)的numpy数组)。
对于给定的data_instance new_data
,我想:
1 /从new_data
(距离欧倍德距离)到centers
最近的中心,将其称为Nearest_center
。
2 /迭代低谷data_history
和:
Nearest_center
(1 /的结果)的元素到名为neighbors
的列表中。neighbors
中对象的标签。贝娄是我的代码,但工作速度慢,我正在寻找更高效的东西。
我的代码
1 /
def getNearestCenter(data,centers):
if centers.shape != (1,2):
dist_ = np.sqrt(np.sum(np.power(data-centers,2),axis=1)) # This compute distance between data and all centers
center = centers[np.argmin(dist_)] # this return center which have the minimum distance from data
else:
center=centers[0]
return center
2 /(优化)
def getLabel(dataPoint, C, history):
labels = []
cluster = getNearestCenter(dataPoint.data,C)
for x in history:
if np.all(getNearestCenter(x.data,C) == cluster):
labels.append(x.true_label)
return labels
答案 0 :(得分:1)
找到它:
dist_ = np.argmin(np.sqrt(np.sum(np.power(data[:, None]-C,2),axis=2)),axis=1)
这应该从centers
的每个数据点返回data
中距离最近的中心的索引。
答案 1 :(得分:1)
您应该使用cdist
中的优化scipy.spatial
,这比使用numpy计算它更有效,
from scipy.spatial.distance import cdist
dist = cdist(data, C, metric='euclidean')
dist_idx = np.argmin(dist, axis=1)
更优雅的解决方案是使用scipy.spatial.cKDTree
(正如评论中@Saullo Castro所指出的那样),对于大型数据集来说可能更快,
from scipy.spatial import cKDTree
tr = cKDTree(C)
dist, dist_idx = tr.query(data, k=1)