我正在使用sklearn.neighbors.BallTree
的自定义指标函数,但我遇到了问题,因为BallTree
似乎在将数据传递给我的指标函数之前更改了数据。这是一个展示这个的例子:
from sklearn.neighbors import BallTree
import numpy as np
np.random.seed(0)
data = np.random.randint(0, 20, size=(2, 3))
def metric(x, y):
print('Data passed to metric')
print(x)
print(y)
return 1
print('Original data')
print(data)
BallTree(data, metric=metric)
这给了我
Original data
[[12 15 0]
[ 3 3 7]]
Data passed to metric
[7.5 9. 3.5]
[12. 15. 0.]
Data passed to metric
[7.5 9. 3.5]
[3. 3. 7.]
在将数据传递给BallTree
之前,metric
做了哪些预处理?有没有办法把它关掉?它甚至似乎改变了对metric
...
(我的实际用例 - 我使用Levenstein距离作为我的指标并使用字符串。但是,由于我不能直接传入字符串,我将每个字符转换为预定义的标记并传入一个数组由于数据被修改,我不再能够撤消编码以在我的度量函数中返回字符串,这样我才能正确计算出Levenstein距离。如果你有更好的解决方案来寻找最近的邻居字符串而不是数字数据,我也很高兴听到这一点。
答案 0 :(得分:1)
它没有。
BallTree
对象不会更改您的数据。
注意:如果X是C连续的双精度数组,则数据不会 复制。否则,将制作内部副本。
get_arrays
函数来获取内部数组,并通过检查源代码,您会发现边界是[7.5, 9. , 3.5]
,这是它正在比较对象的边界def get_arrays(self):
return (self.data_arr, self.idx_array_arr,
self.node_data_arr, self.node_bounds_arr)
输出:
bt.get_arrays()
Out[x]:
(array([[12., 15., 0.],
[ 3., 3., 7.]]), array([0, 1]), array([(0, 2, 1, 1.)],
dtype=[('idx_start', '<i8'), ('idx_end', '<i8'), ('is_leaf', '<i8'), ('radius', '<f8')]), array([[[7.5, 9. , 3.5]]]))
因此,您的指标将应用于数据和节点,而不仅仅是您自己的数据,并且节点与您的数据不同。您可以尝试单词嵌入,这样您就可以计算距离,而无需解码数据。不确定你要做什么,但也许基于树的模型不是你的用例的最佳方式。