sklearn BallTree给出了意想不到的结果

时间:2015-11-11 22:53:02

标签: python scikit-learn

我做错了什么?

我正在尝试使用sklearn的BallTree来提出类似的集合,然后针对特定集合中可能缺少的项目生成一些建议。

import random
from sklearn.neighbors import BallTree
import numpy

collections = []  # 10k sample collections of between
                  # 7 and 15 (of a possible 300...) items

for sample in range(0, 10000):  # build sample data
   items = random.sample(range(1, 300), random.randint(7, 15))
   collections.append(items)    

darray = numpy.zeros((len(collections), max(map(len, collections))))  # 10k x 15 matrix

for c_cnt, items in enumerate(collections):  # populate matrix
   for cnt, i in enumerate(sorted(items)):
      darray[C_cnt][cnt] = i

query = BallTree(darray).query(darray[0], k=15)

nearest_neighbors = query[1][0]

# test the results against the first item!

all_sets = [set(darray[0]) & set(darray[item]) for item in nearest_neighbors]
for item in all_sets:
    print item  # intersection of the neighbor

我得到以下结果:

set([0.0, 130.0, 167.0, 290.0, 162.0, 144.0, 17.0, 214.0]) # Nearest neighbor is itself! Awesome!
set([0.0])  # WTF? The second closest item shares only 1 item?
set([0.0, 290.0])
set([0.0, 17.0])
set([0.0, 130.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0])
set([0.0, 162.0])
set([0.0, 144.0, 162.0])  # uhh okay, i would expect this to be higher up
set([0.0, 144.0, 17.0])

我观察到较高的建议项目往往具有与我试图比较的数组相同的非零值长度。我可以用我的数据做些准备来解决这个问题吗?

1 个答案:

答案 0 :(得分:2)

默认情况下,BallTree会计算向量之间的欧几里德距离,因此不适合您的计算类型。

举一个简单的例子,假设你有以下两套:

collections[0] = [1, 3]
collections[1] = [1, 2, 3]

如上所述,当您将它们转换为darray内的向量时,它们会变为:

darray[0] = [1, 3, 0]
darray[1] = [1, 2, 3]

这些之间的欧几里德距离并不反映集合中类似条目的数量,这就是为什么结果不符合您的预期。

您要查找的距离指标可能是Jaccard distance,而不是欧几里德距离,它可以衡量集合之间的相似性。 BallTree为集合的布尔表示实现了这一点;也就是说,对于上述数据,矢量将成为

darray[0] = [True, False, True]
darray[1] = [True, True, True]

其中第一个条目表示1是否在集合中,第二个条目表示2是否在集合中,依此类推。这是" one-hot编码的版本"。

对于您提供的样本数据,您可以这样计算结果:

import numpy as np
from sklearn.neighbors import BallTree
from sklearn.feature_extraction import DictVectorizer

# for replicability
np.random.seed(0)

# Compute the collections using a more efficient method
collections = [np.random.choice(300, replace=False,
                                size=np.random.randint(7, 15))
               for _ in range(10000)]

# Use DictVectorizer to compute binary representation of collections
dicts = [dict(zip(c, np.ones_like(c))) for c in collections]
darray = DictVectorizer(sparse=False, dtype=bool).fit_transform(dicts)

# Compute 15 nearest neighbors for the first collection
dist, ind = BallTree(darray, metric='jaccard').query(darray[0], k=15)
for i in ind[0]:
    print(set(collections[0]) & set(collections[i]))

我得到以下结果:

{225, 226, 261, 166, 296, 52, 150, 246, 215, 221, 223}
{52, 261, 221, 215}
{225, 226, 166, 150}
{223, 150, 215}
{225, 261, 166, 221}
{226, 261, 223}
{261, 150, 221}
{223, 52, 166, 215}
{296, 226, 166, 223}
{296, 221, 150}
{223, 52, 215}
{52, 261, 246}
{296, 225, 52}
{296, 225, 221}
{225, 150, 223}

请注意,Jaccard相似度不仅仅是交点的大小,而是通过并集大小归一化的大小。仅交叉点的大小不具有距离度量的属性,因此无法直接使用BallTree计算。

编辑:我应该补充一点,如果集合中有许多条目,则此方法变得难以维护,因为布尔编码矩阵变得太大。使用Jaccard距离计算高维邻居搜索的最佳方法可能是通过Locality Sensitive Hashing,但我不知道一个易于使用的Python实现适合这个问题。