sklearn.model_selection.GridSearchCV的LatentDirichletAllocation评分策略

时间:2018-10-25 09:46:29

标签: python scikit-learn nlp lda

我正在尝试使用sklearn库将GridSearchCV应用于LatentDirichletAllocation。

当前管道如下所示:

vectorizer = CountVectorizer(analyzer='word',       
                         min_df=10,                      
                         stop_words='english',           
                         lowercase=True,                 
                         token_pattern='[a-zA-Z0-9]{3,}'
                        )

data_vectorized = vectorizer.fit_transform(doc_clean) #where doc_clean is processed text.

lda_model = LatentDirichletAllocation(n_components =number_of_topics,
                                    max_iter=10,            
                                    learning_method='online',   
                                    random_state=100,       
                                    batch_size=128,         
                                    evaluate_every = -1,    
                                    n_jobs = -1,            
                                    )

search_params = {'n_components': [10, 15, 20, 25, 30], 'learning_decay': [.5, .7, .9]}
model = GridSearchCV(lda_model, param_grid=search_params)
model.fit(data_vectorized)

当前,GridSearchCV使用近似对数似然作为得分来确定哪个是最佳模型。我想做的是将评分方法改为基于模型的approximate perplexity

根据sklearn的documentation of GridSearchCV,我可以使用一个得分参数。但是,我不知道如何将困惑作为一种评分方法,并且我在网上找不到任何使用困惑的例子。这可能吗?

2 个答案:

答案 0 :(得分:1)

GridSearchCV的默认设置将使用管道中最终估算器的score()函数。

make_scorer可以在这里使用,但是要计算困惑度,您还需要来自拟合模型的其他数据,通过make_scorer提供这些数据可能有点复杂。

您可以在此处对LDA进行包装,并可以在其中重新实现score()函数以返回perplexity。大致情况:

class MyLDAWithPerplexityScorer(LatentDirichletAllocation):

    def score(self, X, y=None):

        # You can change the options passed to perplexity here
        score = super(MyLDAWithPerplexityScorer, self).perplexity(X, sub_sampling=False)

        # Since perplexity is lower for better, so we do negative
        return -1*score

然后可以使用它代替代码中的LatentDirichletAllocation,例如:

...
...
...
lda_model = MyLDAWithPerplexityScorer(n_components =number_of_topics,
                                ....
                                ....   
                                n_jobs = -1,            
                                )
...
...

答案 1 :(得分:0)

得分和困惑度参数似乎有问题,并且取决于主题的数量。因此,网格中的结果将为您提供最少的主题

GitHub issue