将Logistic回归模型拟合到MNIST数据需要很长时间

时间:2019-04-26 01:07:01

标签: machine-learning scikit-learn classification logistic-regression

enter image description here我正在尝试将sklearn的LogisticRegression模型应用于MNIST数据集,并且我将训练-测试数据分为70-30个部分。

但是,当我简单地说 model.fit(train_x, train_y)需要很长时间。

启动logistic回归时我没有添加任何参数。

代码:

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_mldata
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

import tempfile

test_data_home = tempfile.mkdtemp()
mnist = fetch_mldata('MNIST original', data_home = test_data_home)


x_train, x_test, y_train, y_test = train_test_split(mnist.data, mnist.target, test_size = 0.30, random_state = 0)

lr = LogisticRegression(penalty = 'l2')
lr.fit(x_train, y_train)

2 个答案:

答案 0 :(得分:1)

您提出的问题似乎还很模糊,但是我很确定您的逻辑回归没有收敛。除非您担心过拟合,否则我不确定自己为什么现在要包括“ L2”惩罚条款。无论如何,如果您查看sklearn docs,它会说:

用于优化问题的算法。

对于小型数据集,“ liblinear”是一个不错的选择,而对于大型数据集,“ sag”和“ saga”则更快。 对于多类问题,只有“ newton-cg”,“ sag”,“ saga”和“ lbfgs”处理多项式损失; “ liblinear”仅限于“一站式”计划。 “ newton-cg”,“ lbfgs”和“ sag”仅处理L2惩罚,而“ liblinear”和“ saga”处理L1惩罚。 请注意,只有在比例大致相同的要素上才能确保“ sag”和“ saga”快速收敛。您可以使用sklearn.preprocessing中的缩放器对数据进行预处理。

我会立即建议您添加参数“ solver = sag”(或任何其他可以处理L2罚分的求解器),因为文档明确指出只有某些求解器可以处理L2罚分,而默认求解器仅是liblinear处理L1罚款。关于逻辑回归的求解器上有一篇非常不错的文章,您可以查看它的数据集:
Solvers for Logistic Regression

请记住,L2和L1正则化是为了处理过度拟合,因此,您甚至可以在lr定义中更改 C 参数。请查看sklearn文档以获取更多信息。希望这可以帮助。

答案 1 :(得分:0)

首先,MINST不是二进制分类,而是多分类。因此,关于scikit-learn中的文档:

  

multi_class:str,{‘ovr’,‘多项式’,‘自动’},默认值:‘ovr’   选择的选项是“ ovr”,则每个选项都适合一个二进制问题   标签。对于“多项式”,最小化的损失就是多项式损失   拟合整个概率分布,即使数据是   二进制当solver =“ liblinear”时,“多项式”不可用。 '汽车'   如果数据是二进制数据,或者如果Solver =“ liblinear”,则选择“ ovr”,然后   否则选择“多项式”。

您需要在模型创建中强调它。

由于MINST具有相同幅度的特征,因此我相信如果您明确地将求解器称为“传奇”,则其收敛速度将比其他求解器快。

因此,我以Scikitlearn here的示例为例,设置训练参数,并将代码更改为:

lr = LogisticRegression(C=50. / train_samples,
                         multi_class='multinomial',
                         penalty='l1', solver='saga', tol=0.1)
lr.fit(x_train, y_train)