如何从kNeighborsClassifier中找到前n个匹配项?

时间:2018-07-17 11:57:26

标签: python machine-learning scikit-learn knn nearest-neighbor

我正在尝试从一组带有标签的样本矢量中搜索一个矢量。我需要找到最佳的n匹配项。我为此使用kNeighborsClassifier

nbrs = KNeighborsClassifier(n_neighbors=2, algorithm='ball_tree', metric='euclidean').fit(train_data_array, train_label)
yp = nbrs.predict(xt)

但是问题是它只返回前1个结果。我认为,基于欧几里得距离,我可以获得前n个匹配项,但我不确定如何提取该信息。

1 个答案:

答案 0 :(得分:3)

有一个kneighbors() method in KNeighborsClassifier可以使用。

它将返回训练数据的索引(您在fit()中使用的索引)以及最接近您在其中提供的点的距离。

示例:

from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
iris = datasets.load_iris()
X = iris.data
y = iris.target

clf = KNeighborsClassifier()
clf.fit(X, y)

# here I am taking a single point only
distances, indices = clf.kneighbors(X[[0]],  n_neighbors=2)

print(distances, indices)

#Output: array([[0., 0.]]), array([[17,  0]])

第一个输出是距离,第二个输出是X的最接近X[[0]]的索引