我试图在sklearn.neighbors.BallTree中使用自定义指标,但是当它调用我的指标时,输入看起来不正确。如果我使用具有相同自定义指标的scipy.spatial.distance.pdist,它将按预期工作。但是,如果我尝试实例化BallTree,当我尝试重塑输入时会引发异常。如果我查看实际输入,则形状和值看起来不正确。
import numpy as np
import scipy.spatial.distance as spdist
import sklearn.neighbors.ball_tree as ball_tree
# custom metric
def minimum_average_direct_flip(x, y):
x = np.reshape(x, (-1, 3))
y = np.reshape(y, (-1, 3))
direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
return min(direct, flipped)
# create an X to test
X = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19], [21, 22, 23, 24, 25, 26, 27, 28, 29]])
# works as expected
distances = spdist.pdist(X, metric=minimum_average_direct_flip)
# outputs: [ 17.32050808 34.64101615 17.32050808]
print distances
# raises exception, inputs to minimum_average_direct_flip look wrong
# Traceback (most recent call last):
# File ".../test_script.py", line 23, in <module>
# ball_tree.BallTree(X, metric=minimum_average_direct_flip)
# File "sklearn/neighbors/binary_tree.pxi", line 1059, in sklearn.neighbors.ball_tree.BinaryTree.__init__ (sklearn\neighbors\ball_tree.c:8381)
# File "sklearn/neighbors/dist_metrics.pyx", line 262, in sklearn.neighbors.dist_metrics.DistanceMetric.get_metric (sklearn\neighbors\dist_metrics.c:4032)
# File "sklearn/neighbors/dist_metrics.pyx", line 1091, in sklearn.neighbors.dist_metrics.PyFuncDistance.__init__ (sklearn\neighbors\dist_metrics.c:10586)
# File "C:/Users/danrs/Documents/neuro_atlas/test_script.py", line 8, in minimum_average_direct_flip
# x = np.reshape(x, (-1, 3))
# File "C:\Anaconda2\lib\site-packages\numpy\core\fromnumeric.py", line 225, in reshape
# return reshape(newshape, order=order)
# ValueError: total size of new array must be unchanged
ball_tree.BallTree(X, metric=minimum_average_direct_flip)
在第一次从BallTree代码调用minimum_average_direct_flip时,输入为:
x = [ 0.4238394 0.55205233 0.04699435 0.19542642 0.20331665 0.44594837 0.35634537 0.8200018 0.28598294 0.34236847]
y = [ 0.4238394 0.55205233 0.04699435 0.19542642 0.20331665 0.44594837 0.35634537 0.8200018 0.28598294 0.34236847]
这些看起来完全错误。我在调用这个方式时做错了什么,或者这是sklearn中的一个错误?
答案 0 :(得分:0)
这似乎是一个已知问题: https://github.com/scikit-learn/scikit-learn/issues/6287
他们做了某种有问题的验证步骤。作为一种解决方法,我想我可以添加一个输入大小的检查,但是因为问题注意到这是不可取的,因为我不能自己进行实际的验证检查。