我有这段代码:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn.neighbors as ng
def mydist(x, y):
return np.sum((x-y)**2)
if __name__ == '__main__':
nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',metric='mydist')
我正在使用sci-kit学习0.18.1而且我收到此错误
ValueError: Metric 'mydist' not valid for algorithm 'ball_tree'
我也尝试过使用algorithm ='brute',但错误仍然存在。
造成这种情况的原因是什么?如何正确使用用户定义的距离指标?
答案 0 :(得分:4)
以下是ball_tree
算法的有效指标列表 - scikit-learn
内部检查指定的指标是否在其中:
In [114]: from sklearn.neighbors import BallTree
In [115]: BallTree.valid_metrics
Out[115]:
['euclidean',
'l2',
'minkowski',
'p',
'manhattan',
'cityblock',
'l1',
'chebyshev',
'infinity',
'seuclidean',
'mahalanobis',
'wminkowski',
'hamming',
'canberra',
'braycurtis',
'matching',
'jaccard',
'dice',
'kulsinski',
'rogerstanimoto',
'russellrao',
'sokalmichener',
'sokalsneath',
'haversine',
'pyfunc'] # <--- NOTE
因此请尝试指定metric='pyfunc'
和metric_params={"func":mydist}
:
knn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',
metric='pyfunc', metric_params={"func":mydist})