Using NearestNeighbors and word2vec to detect sentence similarity

Posted: 2016-03-23 02:29:09

Tags: python scikit-learn nearest-neighbor word2vec

I trained a word2vec model on my corpus using Python and gensim.

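Then I computed the mean word2vec vector of each sentence (the average of the vectors of all the words in the sentence) and stored it in a pandas DataFrame. The columns of the DataFrame df are: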

  • book title (the book the sentence comes from)
  • mean_vector (the mean of the word2vec vectors in the sentence, size 100)
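A minimal sketch of how such a column can be built, assuming a hypothetical `word_vectors` dict as a stand-in for the trained gensim vectors (`model.wv` in gensim) and 3-dimensional toy vectors instead of 100. Words missing from the vocabulary are skipped. Note that each cell of `mean_vector` holds a whole NumPy array, so the column has `object` dtype.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for trained word2vec vectors (gensim's model.wv);
# 3 dimensions here instead of 100 to keep the example small.
word_vectors = {
    "the": np.array([0.1, 0.2, 0.3], dtype=np.float32),
    "cat": np.array([0.5, 0.1, 0.0], dtype=np.float32),
    "sat": np.array([0.0, 0.4, 0.2], dtype=np.float32),
    "dog": np.array([0.4, 0.2, 0.1], dtype=np.float32),
}

def mean_vector(sentence):
    """Average the vectors of the in-vocabulary words of a sentence."""
    vecs = [word_vectors[w] for w in sentence.split() if w in word_vectors]
    return np.mean(vecs, axis=0)

# Hypothetical corpus: (book title, sentence) pairs.
rows = [("book_a", "the cat sat"), ("book_b", "the dog sat")]
df = pd.DataFrame(rows, columns=["title", "sentence"])
df["mean_vector"] = df["sentence"].apply(mean_vector)

print(df["mean_vector"].dtype)     # object: each cell is a whole array
print(df["mean_vector"][0].shape)  # (3,)
```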

I am trying to use scikit-learn's NearestNeighbors to detect sentence similarity (I could use doc2vec, but one of my goals is to compare this approach with doc2vec).

Here is my code:

from sklearn.neighbors import NearestNeighbors

X = df['mean_vector'].values
nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)

I get the following error:

ValueError: setting an array element with a sequence.
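The error most likely comes from the shape of `X`: `df['mean_vector'].values` is a 1-D array of `object` dtype whose elements are themselves arrays, and NumPy cannot cast that into the 2-D float matrix scikit-learn expects. A minimal sketch, assuming a small stand-in DataFrame with 4-dimensional vectors instead of the real 100-dimensional ones, that reproduces the problem and then stacks the cells with `np.vstack` into one row per sentence:

```python
import numpy as np
import pandas as pd

# Stand-in for the question's df: 3 sentences, 4-dimensional mean vectors.
rng = np.random.default_rng(0)
df = pd.DataFrame({"title": ["a", "b", "c"]})
df["mean_vector"] = pd.Series([rng.random(4, dtype=np.float32) for _ in range(3)])

X_bad = df["mean_vector"].values
print(X_bad.shape, X_bad.dtype)  # (3,) object: one cell per sentence, not a matrix

try:
    # This is essentially the conversion scikit-learn attempts inside fit().
    np.asarray(X_bad, dtype=np.float64)
except ValueError as err:
    print("ValueError:", err)

# Stack the per-sentence arrays into an (n_sentences, n_features) matrix.
X = np.vstack(df["mean_vector"].values)
print(X.shape)  # (3, 4)
```

With the real data, `X.shape` should then be `(number of sentences, 100)`, which can be passed straight to `NearestNeighbors(...).fit(X)`.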

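I think I should somehow iterate over the vectors so that I can compute the nearest neighbours of each row (one row per sentence), but that seems to be beyond my current (limited) Python skills.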
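A loop over rows is not needed: once the vectors are in a single (n_sentences, 100) matrix, one `fit` call indexes every sentence and one `kneighbors` call returns the nearest neighbours of all rows at once. A sketch, assuming random stand-in vectors and a hypothetical `titles` list in place of the book-title column. With `n_neighbors=2`, the first neighbour of each row is the row itself (distance 0), so the second column holds the most similar other sentence.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(42)
X = rng.normal(size=(5, 100)).astype(np.float32)  # 5 stand-in sentence vectors
titles = ["book_a", "book_a", "book_b", "book_c", "book_c"]  # hypothetical

# Fit once on all sentences, then query all of them in a single call.
nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(X)
distances, indices = nbrs.kneighbors(X)  # both have shape (5, 2)

for i, (_, other_idx) in enumerate(indices):
    # indices[i, 0] is sentence i itself; indices[i, 1] is its nearest other sentence
    print(f"sentence {i} ({titles[i]}) -> sentence {other_idx} "
          f"({titles[other_idx]}), distance {distances[i, 1]:.3f}")
```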

Here is the data in the first cell, df['mean_vector'][0]. It is the full 100-dimensional vector, averaged over the word vectors of the sentence.

array([ -2.14208905e-02,   2.42093615e-02,  -5.78106642e-02,
     1.32915592e-02,  -2.43393257e-02,  -1.41872400e-02,
     2.83471867e-02,  -2.02910602e-02,  -5.49359620e-02,
    -6.70913085e-02,  -5.56188896e-02,  -2.95186806e-02,
     4.97652516e-02,   7.16793686e-02,   1.81338750e-02,
    -1.50108105e-02,   1.79438610e-02,  -2.41483524e-02,
     4.97504435e-02,   2.91026086e-02,  -6.87966943e-02,
     3.27585079e-02,   5.10644279e-02,   1.97029337e-02,
     7.73109496e-02,   3.23865712e-02,  -2.81659551e-02,
    -9.69715789e-03,   5.23059331e-02,   3.81100960e-02,
    -3.62489261e-02,  -3.40068117e-02,  -4.90736961e-02,
     8.72346922e-04,   2.27111522e-02,   1.06063476e-02,
    -3.93234752e-02,  -1.10617064e-01,   8.05142429e-03,
     4.56497036e-02,  -1.73281748e-02,   2.35153548e-02,
     5.13465842e-03,   1.88336968e-02,   2.40451116e-02,
     3.79024050e-03,  -4.83284928e-02,   2.10295208e-02,
    -4.92134318e-03,   1.01532964e-02,   8.02216958e-03,
    -6.74675079e-03,  -1.39653292e-02,  -2.07276996e-02,
     9.73508134e-03,  -7.37899616e-02,  -2.58320477e-02,
    -1.10700730e-05,  -4.53227758e-02,   2.31859135e-03,
     1.40053956e-02,   1.61973312e-02,   3.01702786e-02,
    -6.96818605e-02,  -3.47468331e-02,   4.79541793e-02,
    -1.78820305e-02,   5.99209731e-03,  -5.92620336e-02,
     7.34678581e-02,  -5.23381204e-05,  -5.07357903e-02,
    -2.55154949e-02,   5.06089740e-02,  -3.70467864e-02,
    -2.04878468e-02,  -7.62404222e-03,  -5.38200373e-03,
     7.68705690e-03,  -3.27000804e-02,  -2.18365286e-02,
     2.34392099e-03,  -3.02998684e-02,   9.42565035e-03,
     3.24523374e-02,  -1.10793915e-02,   3.06244520e-03,
    -1.82240941e-02,  -5.70741761e-03,   3.13486941e-02,
    -1.15621388e-02,   1.10221673e-02,  -3.55655849e-02,
    -4.56304513e-02,   5.54837054e-03,   4.38252240e-02,
     1.57828294e-02,   2.65670624e-02,   8.08797963e-03,
     4.55569401e-02], dtype=float32)

I also tried:

for vec in df['mean_vector']:
    X = vec
    nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)

But then I only get the following warning:

DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and will raise ValueError in 0.19. Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.

If there is an example on GitHub that uses word2vec and NearestNeighbors in a similar scenario, I would be glad to see it.

1 Answer:

Answer 0 (score: 2)

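Edit: the error is thrown because sklearn expects 2D input, with one example per row. For a single vector you can use either X.reshape(1, -1) or [X]; the former is better practice. Without the original data or a proper MWE it is hard to say exactly what went wrong, but my guess is that something breaks when putting the data into or taking it out of the DataFrame. Check that X.shape makes sense to you, i.e. that it is (number of sentences, 100).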
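To make the reshape advice concrete: for one sentence vector, `X.reshape(1, -1)` and `[X]` both give the 2-D shape `(1, n_features)`, whereas `X.reshape(-1, 1)` would describe 100 samples with one feature each, which is not what a sentence vector is. A sketch, assuming random stand-in data, that uses the reshaped vector to query an index fitted on all sentences (fitting on a single sentence, as in the loop above, leaves nothing to compare it with):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X_all = rng.normal(size=(10, 100))  # stand-in for the stacked sentence matrix
vec = X_all[3]                      # one sentence vector, shape (100,)

print(vec.reshape(1, -1).shape)     # (1, 100): one sample with 100 features
print(np.asarray([vec]).shape)      # (1, 100): same shape via a list
print(vec.reshape(-1, 1).shape)     # (100, 1): 100 samples with 1 feature (wrong here)

# Fit on all sentences, then query a single one as a (1, 100) array.
nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(X_all)
distances, indices = nbrs.kneighbors(vec.reshape(1, -1))
print(indices)  # first entry is 3 (the sentence itself), then its nearest neighbour
```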

Below is the example I used to check that everything works for me:

from sklearn.neighbors import NearestNeighbors
from gensim.models import Word2Vec
import numpy as np

a = """Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore
magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea 
commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla 
pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est 
laborum."""
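a = [x.split() for x in a.split('\n') if x.strip()]
model = Word2Vec(a, min_count=1)

# Get the average of all of the words to get data for a sentence
# (model.wv[word] works in gensim 3.x and 4.x; model[word] no longer works in 4.x)
b = np.array([np.mean([model.wv[xx] for xx in x], axis=0) for x in a])
# Check it's the correct shape: (number of sentences, vector size), here (5, 100)
print(b.shape)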

nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(b)
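If cosine similarity is preferred over Euclidean distance (a common choice for word2vec vectors), note that `ball_tree` does not support the cosine metric, but `NearestNeighbors` accepts `metric='cosine'` together with `algorithm='brute'`. This is an alternative to the answer's setup, not part of it; a sketch with random stand-in vectors in place of `b`, where the returned distance is 1 minus the cosine similarity:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(1)
b = rng.normal(size=(5, 100))  # stand-in for the (n_sentences, 100) matrix

# BallTree does not support 'cosine', so use brute-force search instead.
nbrs = NearestNeighbors(n_neighbors=2, metric="cosine", algorithm="brute").fit(b)
distances, indices = nbrs.kneighbors(b)

# Cosine distance = 1 - cosine similarity; column 1 is the nearest other sentence.
similarity = 1.0 - distances[:, 1]
for i, (j, s) in enumerate(zip(indices[:, 1], similarity)):
    print(f"sentence {i} is most similar to sentence {j} (cosine {s:.3f})")
```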