我运行了以下代码片段来训练文本分类模型。我选择了相当多的它并且它运行得非常顺利,但它仍然使用了大量的RAM。我们的数据集非常庞大(1300万个文档+词汇量中的1800万个单词),但在我看来,抛出错误的执行点非常奇怪。脚本:
encoder = LabelEncoder()
y = encoder.fit_transform(categories)
classes = list(range(0, len(encoder.classes_)))
vectorizer = CountVectorizer(vocabulary=vocabulary,
binary=True,
dtype=numpy.int8)
classifier = SGDClassifier(loss='modified_huber',
n_jobs=-1,
average=True,
random_state=1)
tokenpath = modelpath.joinpath("tokens")
for i in range(0, len(batches)):
token_matrix = joblib.load(
tokenpath.joinpath("{}.pickle".format(i)))
batchsize = len(token_matrix)
classifier.partial_fit(
vectorizer.transform(token_matrix),
y[i * batchsize:(i + 1) * batchsize],
classes=classes
)
joblib.dump(classifier, modelpath.joinpath('classifier.pickle'))
joblib.dump(vectorizer, modelpath.joinpath('vectorizer.pickle'))
joblib.dump(encoder, modelpath.joinpath('category_encoder.pickle'))
joblib.dump(options, modelpath.joinpath('extraction_options.pickle'))
我在这一行得到了MemoryError:
joblib.dump(vectorizer, modelpath.joinpath('vectorizer.pickle'))
在执行的这一点上,训练结束并且分类器已经被转储。它应该由垃圾收集器收集,以防需要更多内存。除此之外,如果它不是compressing the data,那么为什么joblib会分配如此多的内存。
我对python垃圾收集器的内部工作原理并不了解。我应该强迫gc.collect()或使用' del'释放那些不再需要的对象的声明?
更新
我尝试过使用HashingVectorizer,尽管它大大减少了内存使用量,但矢量化速度要慢一些,因此它不是一个很好的选择。
我必须挑选矢量化器以便稍后在分类过程中使用它,这样我就可以生成提交给分类器的稀疏矩阵。我将在此处发布我的分类代码:
extracted_features = joblib.Parallel(n_jobs=-1)(
joblib.delayed(features.extractor) (d, extraction_options) for d in documents)
probabilities = classifier.predict_proba(
vectorizer.transform(extracted_features))
predictions = category_encoder.inverse_transform(
probabilities.argmax(axis=1))
trust = probabilities.max(axis=1)
答案 0 :(得分:1)
如果您要向CountVectorizer
提供自定义词汇表,则在分类过程中重新创建词汇表应该不会有问题。当您提供字符串集而不是映射时,您可能希望使用可以访问的解析词汇表:
parsed_vocabulary = vectorizer.vocabulary_
joblib.dump(parsed_vocabulary, modelpath.joinpath('vocabulary.pickle'))
然后加载它并用于重新创建CountVectorizer
:
vectorizer = CountVectorizer(
vocabulary=parsed_vocabulary,
binary=True,
dtype=numpy.int8
)
请注意,您不需要在这里使用joblib;标准泡菜应该执行相同的;你可以使用任何可用的替代方案获得更好的结果,值得一提的是PyTables。
如果它也用于大部分内存,你应该尝试使用原始的vocabulary
来重新创建矢量化器;目前,当提供一组字符串作为词汇表时,矢量化器只是将集合转换为排序列表,因此您不必担心可重复性(尽管我会在生产中使用前仔细检查)。或者您可以将该集合转换为自己的列表。
总结一下:因为你没有fit()
Vectorizer,使用CountVectorizer
的全部附加值是它的transform()
方法;因为整个所需的数据是词汇(和参数),你可以减少记忆消耗酸洗只是你的词汇,无论是否加工。
当您从官方消息来源询问答案时,我想指出:https://github.com/scikit-learn/scikit-learn/issues/3844其中所有者和scikit-learn的撰稿人提及重新创建CountVectorizer
,尽管出于其他目的。您可以更好地报告链接仓库中的问题,但请确保包含导致内存使用量过多的数据集,以使其可重现。
最后,您可以使用前面评论中提到的HashingVectorizer
。
PS:关于gc.collect()
的使用 - 我会在这种情况下试一试;关于技术细节,你会发现很多关于解决这个问题的问题。