在scikit-learn中使用用户定义的k-nn距离度量

时间:2018-04-12 15:55:34

标签: python scikit-learn

我有这段代码:

import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import sklearn.neighbors as ng 

def mydist(x, y):
    return np.sum((x-y)**2)

if __name__ == '__main__':
    nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',metric='mydist')

我正在使用sci-kit学习0.18.1而且我收到此错误

ValueError: Metric 'mydist' not valid for algorithm 'ball_tree'

我也尝试过使用algorithm ='brute',但错误仍然存​​在。

造成这种情况的原因是什么?如何正确使用用户定义的距离指标?

1 个答案:

答案 0 :(得分:4)

以下是ball_tree算法的有效指标列表 - scikit-learn内部检查指定的指标是否在其中:

In [114]: from sklearn.neighbors import BallTree

In [115]: BallTree.valid_metrics
Out[115]:
['euclidean',
 'l2',
 'minkowski',
 'p',
 'manhattan',
 'cityblock',
 'l1',
 'chebyshev',
 'infinity',
 'seuclidean',
 'mahalanobis',
 'wminkowski',
 'hamming',
 'canberra',
 'braycurtis',
 'matching',
 'jaccard',
 'dice',
 'kulsinski',
 'rogerstanimoto',
 'russellrao',
 'sokalmichener',
 'sokalsneath',
 'haversine',
 'pyfunc']       # <--- NOTE

因此请尝试指定metric='pyfunc'metric_params={"func":mydist}

knn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',
                              metric='pyfunc', metric_params={"func":mydist})