更有效的方式来获得最近的中心

时间:2015-06-04 12:45:35

标签: python numpy optimization

我的数据对象是以下的实例:

class data_instance:
    def __init__(self, data, tlabel):
        self.data = data # 1xd numpy array
        self.true_label = tlabel # integer {1,-1}

到目前为止,在代码中,我有一个名为data_history的列表,其中包含data_istance和一组centers(带有形状(k,d)的numpy数组)。

对于给定的data_instance new_data,我想:

  • 1 /从new_data(距离欧倍德距离)到centers最近的中心,将其称为Nearest_center

  • 2 /迭代低谷data_history和:

    • 2.1 /选择最近中心为Nearest_center(1 /的结果)的元素到名为neighbors的列表中。
    • 2.2 /获取neighbors中对象的标签。

贝娄是我的代码,但工作速度慢,我正在寻找更高效的东西。

我的代码

1 /

def getNearestCenter(data,centers):

    if centers.shape != (1,2):
        dist_ = np.sqrt(np.sum(np.power(data-centers,2),axis=1)) # This compute distance between data and all centers

        center = centers[np.argmin(dist_)] # this return center which have the minimum distance from data

    else:
        center=centers[0]
    return center

2 /(优化)

def getLabel(dataPoint, C, history):

    labels = []
    cluster = getNearestCenter(dataPoint.data,C)
    for x in history:
        if  np.all(getNearestCenter(x.data,C) == cluster):
            labels.append(x.true_label)
    return labels

2 个答案:

答案 0 :(得分:1)

找到它:

dist_ = np.argmin(np.sqrt(np.sum(np.power(data[:, None]-C,2),axis=2)),axis=1)

这应该从centers的每个数据点返回data中距离最近的中心的索引。

答案 1 :(得分:1)

您应该使用cdist中的优化scipy.spatial,这比使用numpy计算它更有效,

from scipy.spatial.distance import cdist

dist = cdist(data, C, metric='euclidean')
dist_idx = np.argmin(dist, axis=1)

更优雅的解决方案是使用scipy.spatial.cKDTree(正如评论中@Saullo Castro所指出的那样),对于大型数据集来说可能更快,

from scipy.spatial import cKDTree

tr = cKDTree(C)
dist, dist_idx = tr.query(data, k=1)