无论如何,要分批训练doc2vec模型

时间:2020-06-01 03:49:11

标签: gensim doc2vec

我不知道如何使用doc2vec批量训练模型。由于我将所有数据加载到ram中,因此无法加载

#Import all the dependencies
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

import nltk
nltk.download('punkt')

from nltk.tokenize import word_tokenize
#import ReadExeFileCapstone
import update-doc2vec 
mapData = ReadExeFileCapstone.readData()

# print ('mapData', mapData)

max_epochs = 10000
vec_size = 200
alpha = 0.025

model = Doc2Vec(size=vec_size,
                alpha=alpha,
                min_alpha=0.00025,
                min_count=1,
                dm =1)
data = []
for key in mapData:
    listData = mapData[key]
    # print ("listData: ", len(listData), listData)

    for i in range(len(listData)):
        listToStr = ' '.join([str(elem) for elem in listData[i]]) #convert array to list string
        data.append(listToStr)

tagged_data = [TaggedDocument(words=word_tokenize(_d.lower()), tags=[str(i)]) for i, _d in enumerate(data)]


model.build_vocab(tagged_data)
#build vocab
for epoch in range(max_epochs):
    print('iteration {0}'.format(epoch))
    model.train(tagged_data,
                total_examples=model.corpus_count,
                epochs=model.iter)
    # decrease the learning rate
    model.alpha -= 0.0002
    # fix the learning rate, no decay
    model.min_alpha = model.alpha
# train model   
model.save("d2v_ASM.model")
print("Model Saved")

1 个答案:

答案 0 :(得分:0)

Doc2Vec(以及gensim中的类似模型类)不需要完整的训练数据作为内存列表。他们将接受Python的“可迭代”对象,该对象简单地一次提供一次重复项。

这样的可迭代对象可以从其他来源流式传输项目,例如磁盘上的大文件,甚至是远远大于可用RAM的文件。

不清楚您的ReadExeFileCapstone实用工具类在做什么。 (对于该名称的代码,没有Web命中方法。)但是,可以将其更改为本身返回一个可迭代的对象,该对象每次被迭代时,一次返回原始文本中的每个文本。数据源。然后,您可以将其包装在代码中以创建必要的TaggedDocument对象,这又是可迭代的,而不是完整的内存列表。

有关该技术的合理介绍,请访问:

https://rare-technologies.com/data-streaming-in-python-generators-iterators-iterables/

另外:

    与出版的作品相比,
  • 10000个时期实在是太荒谬了,后者通常使用10-20个时期,对于非常小的数据集,有时更多。 (但是,如此小的数据集不太可能通过类似Doc2Vec的算法来获得良好的结果,该算法需要大量不同的数据。但是,如果遇到内存问题,您的数据集可能并不小。)

  • 请勿在您手动篡改train()值的循环中多次调用alpha。这是不必要且容易出错的–实际上,您当前的代码是错误的,因为从您的0.0002起始字母中减去0.025数千次会导致alpha为负,这是无意义的破坏性值。使用所需的时期数一次调用train()-这样做正确。而且很少需要调整默认的alpha值。

如果您希望获得更多的进度输出,或者只是为了更好地了解每个步骤的进展,请在INFO级别启用日志记录。