Unable to serialize and deserialize a scikit-learn model that uses a custom distance function

Date: 2015-03-08 07:10:33

Tags: python serialization scikit-learn

import numpy, pickle
from sklearn.neighbors import NearestNeighbors, DistanceMetric
from sklearn.externals import joblib  # removed in scikit-learn 0.23; newer code uses "import joblib"

def dist(x, y):  # custom Euclidean distance
    return numpy.sqrt(numpy.sum(numpy.power(x-y, 2)))

X = numpy.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
nbrs = NearestNeighbors(algorithm='ball_tree', metric='pyfunc', metric_params={'func': dist}).fit(X)  # wrap dist as a "pyfunc" metric

joblib.dump(nbrs, 'custom_metric.pkl')  # serialize the fitted model

a = numpy.array([1,2])
nbrs_from_file = joblib.load('custom_metric.pkl')  # deserialize it
nbrs_from_file.kneighbors(a)  # raises the TypeError shown below

$ python a.py

Traceback (most recent call last):
  File "a.py", line 15, in <module>
    nbrs_from_file.kneighbors(a)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/neighbors/base.py", line 332, in kneighbors
    return_distance=return_distance)
  File "binary_tree.pxi", line 1346, in sklearn.neighbors.ball_tree.BinaryTree.query (sklearn/neighbors/ball_tree.c:10326)
  File "ball_tree.pyx", line 126, in sklearn.neighbors.ball_tree.min_rdist (sklearn/neighbors/ball_tree.c:19132)
  File "ball_tree.pyx", line 96, in sklearn.neighbors.ball_tree.min_dist (sklearn/neighbors/ball_tree.c:18893)
  File "binary_tree.pxi", line 1167, in sklearn.neighbors.ball_tree.BinaryTree.dist (sklearn/neighbors/ball_tree.c:9363)
  File "dist_metrics.pyx", line 1093, in sklearn.neighbors.dist_metrics.PyFuncDistance.dist (sklearn/neighbors/dist_metrics.c:9524)
TypeError: argument after ** must be a mapping, not NoneType
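The last frame shows `PyFuncDistance.dist` unpacking a keyword-argument mapping that is `None` after the model is loaded. A plausible reading (an assumption, not confirmed in the question) is that the metric object's pickling hooks do not save the keyword arguments for the user function, so they come back as `None`. The stdlib-only sketch below reproduces that failure mode with hypothetical classes `BuggyPyFuncMetric` and `FixedPyFuncMetric`; they are illustrative stand-ins, not scikit-learn's code, and plain `pickle` is used in place of joblib.

```python
import math
import pickle


def dist(x, y):
    # Euclidean distance, same formula as the question's NumPy version
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


class BuggyPyFuncMetric:
    """Hypothetical metric whose pickling hooks forget the kwargs dict."""

    def __init__(self, func, **kwargs):
        self.func = func
        self.kwargs = kwargs  # {} when no extra arguments are given

    def dist(self, x, y):
        # Unpacking None here raises the same kind of TypeError as the traceback
        return self.func(x, y, **self.kwargs)

    def __getstate__(self):
        # Bug: only the function is saved; kwargs is left out of the state
        return {"func": self.func}

    def __setstate__(self, state):
        self.func = state["func"]
        self.kwargs = state.get("kwargs")  # missing key -> None


class FixedPyFuncMetric(BuggyPyFuncMetric):
    """Same metric, but the pickled state includes kwargs."""

    def __getstate__(self):
        return {"func": self.func, "kwargs": self.kwargs}


x, y = [1, 2], [1, 1]

# Round-trip the buggy metric: kwargs is restored as None
buggy = pickle.loads(pickle.dumps(BuggyPyFuncMetric(dist)))
buggy_error = None
try:
    buggy.dist(x, y)
except TypeError as exc:
    buggy_error = str(exc)
print("buggy:", buggy_error)

# Round-trip the fixed metric: kwargs survives and the call succeeds
fixed = pickle.loads(pickle.dumps(FixedPyFuncMetric(dist)))
print("fixed:", fixed.dist(x, y))  # fixed: 1.0
```

The only difference between the two classes is whether `kwargs` is part of the pickled state. Whether your installed scikit-learn has this problem should be checked against its release notes; with newer versions it may also be worth retrying with the callable passed directly as `metric=dist`, since `NearestNeighbors` accepts a callable metric.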

0 answers:

No answers yet.