使用sklearn计算两个不同列的单独tfidf分数

时间:2016-04-20 00:35:07

标签: python pandas scikit-learn tf-idf

我试图计算一组查询与每个查询的结果集之间的相似性。我想使用tfidf分数和余弦相似性来做到这一点。我遇到的问题是我无法弄清楚如何使用两列生成tfidf矩阵(在pandas数据帧中)。我已经连接了两列并且它工作正常,但它使用起来很尴尬,因为它需要跟踪哪个查询属于哪个结果。我如何一次计算两列的tfidf矩阵?我正在使用熊猫和sklearn。

以下是相关代码:

tf = TfidfVectorizer(analyzer='word', min_df = 0)
tfidf_matrix = tf.fit_transform(df_all['search_term'] + df_all['product_title']) # This line is the issue
feature_names = tf.get_feature_names() 

我试图将df_all [' search_term']和df_all [' product_title']作为参数传递给tf.fit_transform。这显然不起作用,因为它只是将字符串连接在一起,这使得我无法将search_term与product_title进行比较。此外,是否有更好的方法可以解决这个问题?

1 个答案:

答案 0 :(得分:5)

通过将所有单词放在一起,您已经有了一个良好的开端;通常这样的简单管道就足以产生良好的效果。您可以使用pipelinepreprocessing构建更复杂的要素处理管道。以下是它对您的数据有用的方法:

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import FeatureUnion, Pipeline

df_all = pd.DataFrame({'search_term':['hat','cat'], 
                       'product_title':['hat stand','cat in hat']})

transformer = FeatureUnion([
                ('search_term_tfidf', 
                  Pipeline([('extract_field',
                              FunctionTransformer(lambda x: x['search_term'], 
                                                  validate=False)),
                            ('tfidf', 
                              TfidfVectorizer())])),
                ('product_title_tfidf', 
                  Pipeline([('extract_field', 
                              FunctionTransformer(lambda x: x['product_title'], 
                                                  validate=False)),
                            ('tfidf', 
                              TfidfVectorizer())]))]) 

transformer.fit(df_all)

search_vocab = transformer.transformer_list[0][1].steps[1][1].get_feature_names() 
product_vocab = transformer.transformer_list[1][1].steps[1][1].get_feature_names()
vocab = search_vocab + product_vocab

print(vocab)
print(transformer.transform(df_all).toarray())

['cat', 'hat', 'cat', 'hat', 'in', 'stand']

[[ 0.          1.          0.          0.57973867  0.          0.81480247]
 [ 1.          0.          0.6316672   0.44943642  0.6316672   0.        ]]