如何在scikit-learn中为最近邻居使用用户定义的度量标准?

时间:2016-02-20 02:27:28

标签: scikit-learn distance metrics nearest-neighbor

我正在使用scikit-learn 0.18.dev0。我知道在input pipelines之前已经问过同样的问题。我尝试了那里提供的答案,我收到以下错误

{{1}}

错误讯息 {{1}}

看起来此选项已被删除。如何在{{1}}中使用用户定义的矩阵?

1 个答案:

答案 0 :(得分:5)

正确的关键字是metric

import numpy as np
from sklearn.neighbors import NearestNeighbors

def mydist(x, y):
    return np.sum((x-y)**2)

nn = NearestNeighbors(n_neighbors=4, algorithm='ball_tree', metric=myfunc)

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3,   2]])
nn.fit(X)

在开发版本的文档字符串中也提到了这一点:https://github.com/scikit-learn/scikit-learn/blob/86b1ba72771718acbd1e07fbdc5caaf65ae65440/sklearn/neighbors/unsupervised.py#L48