线性回归负载模型无法预期

时间:2019-03-17 15:27:51

标签: python machine-learning scikit-learn linear-regression doc2vec

我用sklearn训练了线性回归模型,获得了5星评级,这已经足够了。我已经使用Doc2vec创建了矢量,并保存了该模型。然后,将线性回归模型保存到另一个文件。我想做的是加载Doc2vec模型和线性回归模型,并尝试预测另一条评论。

这种预测有一个非常奇怪的地方:无论它总是在2.1-3.0左右进行预测。

事实是,我建议它可以预测平均5左右(即2.5 +/-),但事实并非如此。训练模型时,我已经打印了测试数据的预测值和实际值,通常范围为1-5。所以我的想法是,代码的加载部分出了点问题。这是我的加载代码:

from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from bs4 import BeautifulSoup
from joblib import dump, load
import pickle
import re

model = Doc2Vec.load('../vectors/750000/doc2vec_model')

def cleanText(text):
    text = BeautifulSoup(text, "lxml").text
    text = re.sub(r'\|\|\|', r' ', text) 
    text = re.sub(r'http\S+', r'<URL>', text)
    text = re.sub(r'[^\w\s]','',text)
    text = text.lower()
    text = text.replace('x', '')
    return text

review = cleanText("Horrible movie! I don't recommend it to anyone!").split()
vector = model.infer_vector(review)

pkl_filename = "../vectors/750000/linear_regression_model.joblib"
with open(pkl_filename, 'rb') as file:  
    linreg = pickle.load(file)

review_vector = vector.reshape(1,-1)
predict_star = linreg.predict(review_vector)
print(predict_star)

2 个答案:

答案 0 :(得分:0)

更新:我忽略了.split()之后在问题代码中进行的.cleanText()标记化,所以这不是真正的问题。但是请保持答案以供参考&,因为真正的问题是在评论中发现的。)

通常,当用户向Doc2Vec提供纯字符串时,用户会从infer_vector()中获得神秘弱的结果。 Doc2Vec infer_vector()需要一个单词列表,不是一个字符串。

如果提供字符串,则该函数会将其视为一个字符列表的单词-根据Python将字符串建模为字符列表,以及字符和一个字符字符串的类型合并。模型可能不知道大多数这些单字符单词,而可能是'i''a'等的单词意义不大。因此,推断的doc-vector将是弱的且毫无意义。 (而且,将这样的向量反馈到线性回归中,总是给出中等的预测值也就不足为奇了。)

如果将文本分成预期的单词列表,则结果应该会有所改善。

但是更普遍的是,提供给infer_vector()的单词应经过预处理并标记为准确,而培训文档却是。

(对您是否正确进行推理的合理性测试是,为某些训练文档推断向量,然后向Doc2Vec模型询问最接近这些被重新推断向量的doc标签。通常,同一文档的训练时标签/ ID应该是最重要的结果,或者至少是最重要的几个结果。如果不是,则数据,模型参数或推断中可能还存在其他问题。)

答案 1 :(得分:0)

您的示例代码显示了joblib.dumpjoblib.load的导入-尽管本节未使用。并且,文件的后缀暗示该模型最初可能是用joblib.dump()保存的,而不是香草酱。

但是,此代码显示仅通过纯pickle.load()加载文件-这可能是错误的来源。

The joblib.load() docs建议其load()做一些事情,例如从自己的dump()创建的多个单独文件中加载numpy数组。 (奇怪的是,dump()文档对此不太清楚,但据推测dump()的返回值可能是文件名的 list 。)

您可以检查文件的保存位置,以查找似乎相关的其他文件,然后尝试使用joblib.load()而不是普通文件,以查看该文件是否加载了功能更完整的版本linreg个对象。