Custom metric with sklearn.neighbors.BallTree receives the wrong input?

Date: 2016-08-07 16:07:38

Tags: python scikit-learn nearest-neighbor metric

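I am trying to use a custom metric with sklearn.neighbors.BallTree, but when it calls my metric, the inputs do not look correct. If I use scipy.spatial.distance.pdist with the same custom metric, it works as expected. However, if I try to instantiate a BallTree, an exception is raised when my metric tries to reshape its input. When I inspect the actual input, both its shape and its values look wrong.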

import numpy as np
import scipy.spatial.distance as spdist
import sklearn.neighbors.ball_tree as ball_tree


# custom metric: each 9-element row is treated as three 3-D points
def minimum_average_direct_flip(x, y):
    x = np.reshape(x, (-1, 3))
    y = np.reshape(y, (-1, 3))
    # mean point-to-point distance with the points paired in the same order
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    # mean point-to-point distance with the point order of x reversed
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return min(direct, flipped)

# create an X to test
X = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9],
              [11, 12, 13, 14, 15, 16, 17, 18, 19],
              [21, 22, 23, 24, 25, 26, 27, 28, 29]])

# works as expected
distances = spdist.pdist(X, metric=minimum_average_direct_flip)

# outputs: [ 17.32050808  34.64101615  17.32050808]
print(distances)

# raises exception, inputs to minimum_average_direct_flip look wrong
# Traceback (most recent call last):
#   File ".../test_script.py", line 23, in <module>
#     ball_tree.BallTree(X, metric=minimum_average_direct_flip)
#   File "sklearn/neighbors/binary_tree.pxi", line 1059, in sklearn.neighbors.ball_tree.BinaryTree.__init__ (sklearn\neighbors\ball_tree.c:8381)
#   File "sklearn/neighbors/dist_metrics.pyx", line 262, in sklearn.neighbors.dist_metrics.DistanceMetric.get_metric (sklearn\neighbors\dist_metrics.c:4032)
#   File "sklearn/neighbors/dist_metrics.pyx", line 1091, in sklearn.neighbors.dist_metrics.PyFuncDistance.__init__ (sklearn\neighbors\dist_metrics.c:10586)
#   File "C:/Users/danrs/Documents/neuro_atlas/test_script.py", line 8, in minimum_average_direct_flip
#     x = np.reshape(x, (-1, 3))
#   File "C:\Anaconda2\lib\site-packages\numpy\core\fromnumeric.py", line 225, in reshape
#     return reshape(newshape, order=order)
# ValueError: total size of new array must be unchanged
ball_tree.BallTree(X, metric=minimum_average_direct_flip)
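For contrast, the sketch below mimics what pdist does with a callable metric. It calls the function once per pair of rows (i, j) with i < j, always passing complete 9-element rows, and collects the results in the condensed order. The example is numpy-only and uses itertools.combinations to stand in for pdist's pair loop. It reproduces the output printed above.

```python
import itertools

import numpy as np


def minimum_average_direct_flip(x, y):
    # Same metric as in the question: rows are flattened lists of 3-D points
    x = np.reshape(x, (-1, 3))
    y = np.reshape(y, (-1, 3))
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return min(direct, flipped)


X = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9],
              [11, 12, 13, 14, 15, 16, 17, 18, 19],
              [21, 22, 23, 24, 25, 26, 27, 28, 29]])

# One metric call per pair (0, 1), (0, 2), (1, 2); each call receives whole rows
condensed = np.array([minimum_average_direct_flip(X[i], X[j])
                      for i, j in itertools.combinations(range(len(X)), 2)])
print(condensed)  # approximately [17.3205 34.6410 17.3205]
```

Every call here gets inputs whose length is a multiple of 3, so the reshape inside the metric always succeeds.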

On the first call to minimum_average_direct_flip from the BallTree code, the inputs are:

x = [ 0.4238394   0.55205233  0.04699435  0.19542642  0.20331665  0.44594837 0.35634537  0.8200018   0.28598294  0.34236847]
y = [ 0.4238394   0.55205233  0.04699435  0.19542642  0.20331665  0.44594837 0.35634537  0.8200018   0.28598294  0.34236847]

These look completely wrong. Am I calling this incorrectly, or is this a bug in sklearn?
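In the traceback, the failing call comes from PyFuncDistance.__init__, which runs while the metric object is being created and before any rows of X are used. The printed inputs are 10 numbers in [0, 1), and x and y are identical. That pattern fits a check that passes one random length-10 vector as both arguments. The numpy-only sketch below assumes exactly that probe. The name `probe` and the length of 10 are inferred from the printed values, not taken from scikit-learn's source. Under that assumption, the sketch reproduces the same ValueError, while a genuine 9-element row works.

```python
import numpy as np


def minimum_average_direct_flip(x, y):
    # Same metric as in the question: rows are flattened lists of 3-D points
    x = np.reshape(x, (-1, 3))
    y = np.reshape(y, (-1, 3))
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return min(direct, flipped)


# Assumed stand-in for the construction-time check: one random vector of
# length 10, passed as both x and y (inferred from the printed inputs)
probe = np.random.random(10)

try:
    minimum_average_direct_flip(probe, probe)
    probe_failed = False
except ValueError as exc:
    # 10 is not a multiple of 3, so reshape(-1, 3) cannot succeed
    probe_failed = True
    print("probe rejected:", exc)

# A real 9-element row reshapes cleanly; identical rows give distance 0
row = np.arange(1.0, 10.0)
print(minimum_average_direct_flip(row, row))  # 0.0
```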

1 answer:

Answer 0 (score: 0)

This appears to be a known issue: https://github.com/scikit-learn/scikit-learn/issues/6287

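They run a flawed validation step when the metric is set up. As a workaround, I suppose I could add a check on the input size, but as the issue points out, that is undesirable because it prevents me from doing any real input validation myself.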
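The input-size workaround could look roughly like the sketch below. The wrapper `guarded_metric` is a hypothetical helper. It returns a dummy float (0.0, an arbitrary choice) whenever an input cannot be split into 3-D points, such as the length-10 probe. Otherwise it delegates to the real metric. As the issue notes, this silences the check rather than satisfying it, and the wrapper can no longer reject genuinely malformed input.

```python
import numpy as np


def minimum_average_direct_flip(x, y):
    # Same metric as in the question: rows are flattened lists of 3-D points
    x = np.reshape(x, (-1, 3))
    y = np.reshape(y, (-1, 3))
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
    return min(direct, flipped)


def guarded_metric(x, y):
    # Hypothetical wrapper: inputs whose length is not a multiple of 3
    # (e.g. the validation probe) get a dummy float instead of an exception
    if np.size(x) % 3 != 0 or np.size(y) % 3 != 0:
        return 0.0
    return minimum_average_direct_flip(x, y)


probe = np.random.random(10)
print(guarded_metric(probe, probe))        # 0.0, the check is bypassed
row = np.arange(9.0)
print(guarded_metric(row, row + 10))       # approximately 17.3205
```

The wrapper would then be passed in place of the original function, e.g. `BallTree(X, metric=guarded_metric)`. Whether that is enough may depend on the scikit-learn version in use.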