sklearn BallTree更改传递给指标的数据

时间:2018-04-03 20:49:24

标签: python scikit-learn

我正在使用sklearn.neighbors.BallTree的自定义指标函数,但我遇到了问题,因为BallTree似乎在将数据传递给我的指标函数之前更改了数据。这是一个展示这个的例子:

from sklearn.neighbors import BallTree
import numpy as np

np.random.seed(0)
data = np.random.randint(0, 20, size=(2, 3))
def metric(x, y):
    print('Data passed to metric')
    print(x)
    print(y)
    return 1

print('Original data')
print(data)
BallTree(data, metric=metric)

这给了我

Original data
[[12 15  0]
 [ 3  3  7]]
Data passed to metric
[7.5 9.  3.5]
[12. 15.  0.]
Data passed to metric
[7.5 9.  3.5]
[3. 3. 7.]

在将数据传递给BallTree之前,metric做了哪些预处理?有没有办法把它关掉?它甚至似乎改变了对metric ...

的调用之间的数据

(我的实际用例 - 我使用Levenstein距离作为我的指标并使用字符串。但是,由于我不能直接传入字符串,我将每个字符转换为预定义的标记并传入一个数组由于数据被修改,我不再能够撤消编码以在我的度量函数中返回字符串,这样我才能正确计算出Levenstein距离。如果你有更好的解决方案来寻找最近的邻居字符串而不是数字数据,我也很高兴听到这一点。

1 个答案:

答案 0 :(得分:1)

它没有。

BallTree对象不会更改您的数据。

  1. 它会创建您的数据副本,因为:
  2.   

    注意:如果X是C连续的双精度数组,则数据不会   复制。否则,将制作内部副本。

    1. 它计算对象和树节点边界之间的距离。如您所见,您可以使用get_arrays函数来获取内部数组,并通过检查源代码,您会发现边界是[7.5, 9. , 3.5],这是它正在比较对象的边界
    2. Source

      def get_arrays(self):
              return (self.data_arr, self.idx_array_arr,
                      self.node_data_arr, self.node_bounds_arr)
      

      输出:

      bt.get_arrays()                                                                                                                                                                                           
      Out[x]:                                                                                                                                                                                                           
      (array([[12., 15.,  0.],                                                                                                                                                                                           
              [ 3.,  3.,  7.]]), array([0, 1]), array([(0, 2, 1, 1.)],                                                                                                                                                   
             dtype=[('idx_start', '<i8'), ('idx_end', '<i8'), ('is_leaf', '<i8'), ('radius', '<f8')]), array([[[7.5, 9. , 3.5]]])) 
      

      因此,您的指标将应用于数据和节点,而不仅仅是您自己的数据,并且节点与您的数据不同。您可以尝试单词嵌入,这样您就可以计算距离,而无需解码数据。不确定你要做什么,但也许基于树的模型不是你的用例的最佳方式。