使用Sci-Kit Learn防止Logistic回归中的过度拟合

时间:2017-10-27 15:05:09

标签: python machine-learning scikit-learn logistic-regression data-science

我使用Logistic回归训练了一个模型来预测名称字段和描述字段是否属于男性,女性或品牌的个人资料。我的列车精度约为99%,而我的测试精度约为83%。我尝试通过调整C参数来实现正则化,但几乎没有注意到这些改进。我的训练集中有大约5,000个例子。这是一个我只需要更多数据的实例,还是我可以在Sci-Kit中做些什么来学习如何提高我的测试精度?

2 个答案:

答案 0 :(得分:1)

过度拟合是一个多方面的问题。它可能是你的火车/测试/验证分裂(从50/40/10到90/9/1的任何东西都可以改变一些事情)。您可能需要随意调整输入。尝试使用整体方法,或减少功能的数量。你可能会有异常值掉东西

然后,它可能不是这些,或者所有这些,或者这些的组合。

对于初学者,尝试将测试分数绘制为测试分割大小的函数,并看看你得到了什么

答案 1 :(得分:-2)

#The 'C' value in Logistic Regresion works very similar as the Support 
#Vector Machine (SVM) algorithm, when I use SVM I like to use #Gridsearch 
#to find the best posible fit values for 'C' and 'gamma',
#maybe this can give you some light:

# For SVC You can remove the gamma and kernel keys 
# param_grid = {'C': [0.1,1, 10, 100, 1000], 
#                'gamma': [1,0.1,0.01,0.001,0.0001], 
#                'kernel': ['rbf']} 

param_grid = {'C': [0.1,1, 10, 100, 1000]} 

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report,confusion_matrix

# Train and fit your model to see initial values
X_train, X_test, y_train, y_test = train_test_split(df_feat, np.ravel(df_target), test_size=0.30, random_state=101)
model = SVC()
model.fit(X_train,y_train)
predictions = model.predict(X_test)
print(confusion_matrix(y_test,predictions))
print(classification_report(y_test,predictions))

# Find the best 'C' value
grid = GridSearchCV(SVC(),param_grid,refit=True,verbose=3)
grid.best_params_
c_val = grid.best_estimator_.C

#Then you can re-run predictions on this grid object just like you would with a normal model.
grid_predictions = grid.predict(X_test)

# use the best 'C' value found by GridSearch and reload your LogisticRegression module
logmodel = LogisticRegression(C=c_val)
logmodel.fit(X_train,y_train)

print(confusion_matrix(y_test,grid_predictions))
print(classification_report(y_test,grid_predictions))