使用gensim建立的GridSearch for doc2vec模型

时间:2018-10-18 14:12:45

标签: machine-learning gensim grid-search doc2vec hyperparameters

我正在尝试为我训练有素的doc2vec gensim模型找到最佳的超参数,该模型将文档作为输入并创建其文档嵌入。我的火车数据包含文本文档,但没有任何标签。即我只有'X',但没有'y'。

我在这里发现了一些与我想做的事情有关的问题,但是所有解决方案都是针对监督模型提出的,而没有针对像我这样的无监督模型提出的。

这是我训练doc2vec模型的代码:

def train_doc2vec(
    self,
    X: List[List[str]],
    epochs: int=10,
    learning_rate: float=0.0002) -> gensim.models.doc2vec:

    tagged_documents = list()

    for idx, w in enumerate(X):
        td = TaggedDocument(to_unicode(str.encode(' '.join(w))).split(), [str(idx)])
        tagged_documents.append(td)

    model = Doc2Vec(**self.params_doc2vec)
    model.build_vocab(tagged_documents)

    for epoch in range(epochs):
        model.train(tagged_documents,
                    total_examples=model.corpus_count,
                    epochs=model.epochs)
        # decrease the learning rate
        model.alpha -= learning_rate
        # fix the learning rate, no decay
        model.min_alpha = model.alpha

    return model

我需要有关如何使用GridSearch进行训练的模型并为其找到最佳超参数的建议,或者有关其他技术的任何建议。非常感谢您的帮助。

1 个答案:

答案 0 :(得分:2)

通过代码的正确性,我将尝试回答您有关如何执行超参数调整的问题。 您必须开始定义一组超参数,这些超参数将定义您的超参数网格搜索。对于每组超参数

Hset1 =(par1Value1,par2Value1,...,par3Value1)

您在训练集上训练模型,并使用独立的验证集来衡量您的准确性(或您希望使用的任何度量)。您存储此值(例如A_Hset1)。当对所有可能的超参数集执行此操作时,您将拥有一组度量

(A_Hset1,A_Hset2,A_Hset3 ... A_HsetK)。

其中的每一项指标都可以告诉您,每组超参数的模型效果如何? 您的一组最佳超参数

H_setOptimal = HsetX | A_setX = max(A_Hset1,A_Hset2,A_Hset3 ... A_HsetK)

为了进行公平的比较,您应该始终在相同的数据上训练模型,并始终使用相同的验证集。

我不是Python的高级用户,所以您大概可以找到更好的建议,但是我要做的是创建一个字典列表,其中每个字典都包含一组要测试的超参数:< / p>

grid_search=[{"par1":"val1","par2":"val1","par3":"val1",..., "res"=""},
             {"par1":"val2","par2":"val1","par3":"val1",..., "res"=""},
             {"par1":"val3","par2":"val1","par3":"val1",..., "res"=""},
             ,...,
             {"par1":"valn","par2":"valn","par3":"valn",..., "res"=""}]

以便您可以将结果存储在相应词典的“ res”字段中,并跟踪每组参数的性能。

for set in grid_search:
  #insert here your training and accuracy evaluation using the
  #parameters in set
  
  set["res"]= the_Accuracy_for_HyperPar_in_set

希望对您有帮助。