import numpy, pickle
from sklearn.neighbors import NearestNeighbors, DistanceMetric
from sklearn.externals import joblib
def dist(x,y):
return numpy.sqrt(numpy.sum(numpy.power(x-y, 2)))
X = numpy.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
nbrs = NearestNeighbors(algorithm='ball_tree', metric='pyfunc', metric_params={'func':dist}).fit(X)
joblib.dump(nbrs, 'custom_metric.pkl')
a = numpy.array([1,2])
nbrs_from_file = joblib.load('custom_metric.pkl')
nbrs_from_file.kneighbors(a)
$ python a.py
Traceback (most recent call last):
File "a.py", line 15, in <module>
nbrs_from_file.kneighbors(a)
File "/usr/local/lib/python2.7/dist-packages/sklearn/neighbors/base.py", line 332, in kneighbors
return_distance=return_distance)
File "binary_tree.pxi", line 1346, in sklearn.neighbors.ball_tree.BinaryTree.query (sklearn/neighbors/ball_tree.c:10326)
File "ball_tree.pyx", line 126, in sklearn.neighbors.ball_tree.min_rdist (sklearn/neighbors/ball_tree.c:19132)
File "ball_tree.pyx", line 96, in sklearn.neighbors.ball_tree.min_dist (sklearn/neighbors/ball_tree.c:18893)
File "binary_tree.pxi", line 1167, in sklearn.neighbors.ball_tree.BinaryTree.dist (sklearn/neighbors/ball_tree.c:9363)
File "dist_metrics.pyx", line 1093, in sklearn.neighbors.dist_metrics.PyFuncDistance.dist (sklearn/neighbors/dist_metrics.c:9524)
TypeError: argument after ** must be a mapping, not NoneType