Return a list of nearest neighbors from KNN

Time: 2018-10-04 10:43:07

Tags: machine-learning scikit-learn

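I'm trying to use a KNN model to show the brands most closely related to brand X. I've read in the data and transposed it, so it has the following format: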

          User1     User2     User3     User4     User5
Brand1    1         0         0         0         1
Brand2    0         0         0         1         1
Brand3    0         0         1         1         1
Brand4    1         1         1         0         1
Brand5    0         0         0         1         1

Then I defined the model:

from sklearn.neighbors import NearestNeighbors

model_knn = NearestNeighbors(metric='cosine', algorithm='brute')
model_knn.fit(df_mini)
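For comparison, the same model can be fitted on the five-brand toy table shown above. This is a minimal sketch, assuming the 0/1 values in that table are the whole dataset; it queries all neighbours of `Brand2` to show what the distances look like when brands do share users.

```python
import pandas as pd
from sklearn.neighbors import NearestNeighbors

# Toy brand x user table from the question (rows = brands, columns = users)
df_mini = pd.DataFrame(
    [[1, 0, 0, 0, 1],
     [0, 0, 0, 1, 1],
     [0, 0, 1, 1, 1],
     [1, 1, 1, 0, 1],
     [0, 0, 0, 1, 1]],
    index=['Brand1', 'Brand2', 'Brand3', 'Brand4', 'Brand5'],
    columns=['User1', 'User2', 'User3', 'User4', 'User5'],
)

# Brute-force search with cosine distance, as in the question
model_knn = NearestNeighbors(metric='cosine', algorithm='brute')
model_knn.fit(df_mini)

# All 5 neighbours of Brand2; a one-row DataFrame keeps the feature names
distances, indices = model_knn.kneighbors(df_mini.loc[['Brand2']], n_neighbors=5)
for d, i in zip(distances[0], indices[0]):
    print(df_mini.index[i], round(float(d), 3))
```

Here Brand5 has exactly the same users as Brand2, so both come back at a distance of about 0, followed by Brand3 (about 0.18), Brand1 (about 0.5) and Brand4 (about 0.65). Brands with overlapping users get distances below 1.0.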

Then I use the following code to list the 5 brands closest to a randomly selected brand:

import numpy as np

# Pick a random brand (row) as the query
query_index = np.random.choice(df_mini.shape[0])
# Ask for 6 neighbours: the first one is normally the query brand itself
distances, indices = model_knn.kneighbors(df_mini.iloc[query_index, :].values.reshape(1, -1), n_neighbors=6)

for i in range(len(distances.flatten())):
    if i == 0:
        print('Recommendations for {0}:\n'.format(df_mini.index[query_index]))
    else:
        print('{0}: {1}, with distance of {2}:'.format(i, df_mini.index[indices.flatten()[i]], distances.flatten()[i]))
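Since the goal is a list of neighbours rather than printed lines, here is a minimal sketch of a helper that returns one, assuming a fitted model like `model_knn` and a brand-indexed DataFrame like `df_mini`. The name `nearest_brands` is hypothetical. It requests one extra neighbour because the query row normally matches itself, then removes the query brand by name instead of assuming it sits at position 0 (with ties at distance 0, another brand can be ranked first).

```python
import pandas as pd
from sklearn.neighbors import NearestNeighbors


def nearest_brands(model, df, brand, k=5):
    """Hypothetical helper: the k brands closest to `brand`, as (name, distance) pairs."""
    # One extra neighbour, since the query brand usually matches itself
    n = min(k + 1, len(df))
    distances, indices = model.kneighbors(df.loc[[brand]], n_neighbors=n)
    pairs = [(df.index[i], float(d)) for d, i in zip(distances[0], indices[0])]
    # Remove the query brand wherever it landed in the ranking
    return [(name, d) for name, d in pairs if name != brand][:k]


# Toy data from the question (rows = brands, columns = users)
df_mini = pd.DataFrame(
    [[1, 0, 0, 0, 1],
     [0, 0, 0, 1, 1],
     [0, 0, 1, 1, 1],
     [1, 1, 1, 0, 1],
     [0, 0, 0, 1, 1]],
    index=['Brand1', 'Brand2', 'Brand3', 'Brand4', 'Brand5'],
    columns=['User1', 'User2', 'User3', 'User4', 'User5'],
)
model_knn = NearestNeighbors(metric='cosine', algorithm='brute').fit(df_mini)

for rank, (name, dist) in enumerate(nearest_brands(model_knn, df_mini, 'Brand2', k=3), start=1):
    print('{0}: {1}, with distance of {2:.3f}'.format(rank, name, dist))
```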

This returns example results like the following:

Recommendations for BRAND_X:

1: BRAND_a, with distance of 1.0:
2: BRAND_b, with distance of 1.0:
3: BRAND_c, with distance of 1.0:
4: BRAND_d, with distance of 1.0:
5: BRAND_e, with distance of 1.0:

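Every result I get lists all brands with a distance of 1.0. Where is my code going wrong? I've tried increasing the size of the sample data and the result stays the same, which makes me think this is a bug in my code rather than a quirk of the data.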
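Cosine distance is 1 - (a·b)/(|a||b|), so for 0/1 brand vectors a distance of exactly 1.0 means the two brands have no users in common. In scikit-learn, a brand whose row is all zeros (no users in the sampled rows) also ends up at distance 1.0 from every other brand, because normalizing a zero vector leaves it at zero. A minimal sketch for checking this, using the toy table plus a hypothetical `Brand6` that has no users:

```python
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_distances

# Toy table from the question plus a hypothetical Brand6 that has no users
df_mini = pd.DataFrame(
    [[1, 0, 0, 0, 1],
     [0, 0, 0, 1, 1],
     [0, 0, 1, 1, 1],
     [1, 1, 1, 0, 1],
     [0, 0, 0, 1, 1],
     [0, 0, 0, 0, 0]],
    index=['Brand1', 'Brand2', 'Brand3', 'Brand4', 'Brand5', 'Brand6'],
    columns=['User1', 'User2', 'User3', 'User4', 'User5'],
)

# Brands with no users in the sample
empty = list(df_mini.index[df_mini.sum(axis=1) == 0])
print('brands with no users:', empty)

# Full pairwise cosine distance matrix between brands
D = cosine_distances(df_mini.values)
row = df_mini.index.get_loc('Brand6')
others = np.delete(D[row], row)
# Every other brand is exactly 1.0 away from the empty one
print(others)

# Share of brand pairs (excluding self-pairs) at distance 1.0, i.e. no shared users
off_diag = D[~np.eye(len(D), dtype=bool)]
print('fraction of pairs at distance 1.0:', np.isclose(off_diag, 1.0).mean())
```

If a large share of brand pairs in the real data sits at 1.0, the output above reflects sparse data rather than a bug in the loop.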

Edit: here is a complete example of my code:

import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors

df = pd.read_csv('sample.csv')
print(df.head())

# Keep the first 5000 users, then transpose so that each row is a brand
df_mini = df[:5000]
df_mini = df_mini.transpose()
# After transposing, UserID is a row of data; drop it so only brand rows remain
df_mini = df_mini.drop('UserID', axis=0)

# Brute-force nearest-neighbour search using cosine distance
model_knn = NearestNeighbors(metric='cosine', algorithm='brute')
model_knn.fit(df_mini)

# Pick a random brand and ask for 6 neighbours (the first is normally the brand itself)
query_index = np.random.choice(df_mini.shape[0])
distances, indices = model_knn.kneighbors(df_mini.iloc[query_index, :].values.reshape(1, -1), n_neighbors=6)

for i in range(len(distances.flatten())):
    if i == 0:
        print('Recommendations for {0}:\n'.format(df_mini.index[query_index]))
    else:
        print('{0}: {1}, with distance of {2}:'.format(i, df_mini.index[indices.flatten()[i]], distances.flatten()[i]))
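Because a larger sample did not change the result, it may be worth checking how sparse `df_mini` is after the reshape. A minimal sketch, assuming `sample.csv` has one row per user with a `UserID` column and one 0/1 column per brand; the three-user frame below is a hypothetical stand-in for that file. Calling `set_index('UserID')` before transposing turns the user IDs into column labels instead of a data row, so the brand rows keep a numeric dtype even when `UserID` holds strings (a plain `transpose()` of a frame mixing string and integer columns yields object-dtype columns).

```python
import pandas as pd

# Hypothetical stand-in for sample.csv: one row per user, one 0/1 column per brand
df = pd.DataFrame({
    'UserID': ['u1', 'u2', 'u3'],
    'BrandA': [1, 0, 0],
    'BrandB': [0, 1, 0],
    'BrandC': [0, 0, 0],
})

# Same reshaping as in the question, but UserID becomes the column labels
df_mini = df[:5000].set_index('UserID').T
print(df_mini.dtypes.unique())

# How many brands have no users at all in this slice?
n_empty = int((df_mini.sum(axis=1) == 0).sum())
# Fraction of non-zero cells; very sparse data makes cosine distances of 1.0 common
density = float((df_mini.values != 0).mean())
print('brands with no users:', n_empty, 'density:', density)
```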

Sample data file: https://drive.google.com/open?id=19KRJDGrsLNpDD0WNAz4be76O66fGmQtJ

0 answers:

No answers