用于 Sickit 学习中的 One vs rest 的优化求解器

时间:2021-03-03 21:22:14

标签: python machine-learning scikit-learn logistic-regression

我正在尝试使用逻辑回归解决多类分类问题。我的数据集有 3 个不同的类,每个数据点只属于一个类。这是样本 training_data; enter image description here

这里的第一列是我作为偏置项添加的向量。并且目标列已经使用标签二值化的概念进行二值化,如sickit-learn

然后我得到的目标如下;

array([[1, 0, 0],
   [1, 0, 0],
   [0, 1, 0],
   [1, 0, 0],
   [1, 0, 0]])

接下来,我使用 one vs. rest 的概念来训练它,即一次训练一个分类器。示例代码;


for i in range(label_train.shape[1]):
    clf = LogisticRegression(random_state=0,multi_class='ovr', solver='liblinear',fit_intercept=True).\
 fit(train_data_copy, label_train[:,i])
    #print(clf.coef_.shape)

如您所见,我总共训练了 3 个分类器,每个分类器对应一个。我在这里有两个问题;

第一个问题:根据sickit-learn 文档,

<块引用>

multi_class{‘auto’, ‘ovr’, ‘multinomial’}, default='auto' 如果选择的选项是“ovr”,那么每个标签都适合一个二元问题。对于“多项式”,最小化的损失是拟合整个概率分布的多项式损失,即使数据是二进制的。当 solver='liblinear' 时,'multinomial' 不可用。如果数据是二进制的,或者如果solver='liblinear','auto'选择'ovr',否则选择'multinomial'。

我的问题是,既然我选择了求解器作为 liblinear(作为 o.v.r 问题),那么我选择 multi_class 作为 auto 还是 ovr 有关系吗。

第二个问题,关于截距(或偏差)项。文档说,如果 fit_intercept=True 然后一个偏差项被添加到决策函数中。但是我注意到,当我没有将 1 的向量添加到我的数据矩阵时,系数中的参数数量,theta 向量与特征数量相同,尽管 fit_intercept=True。我的问题是,我们是否必须将 1 的向量添加到数据矩阵中,并启用 fit_intercept 以便将偏差项添加到决策函数中。

1 个答案:

答案 0 :(得分:0)

  1. 没关系;正如您所看到的 here,无论何时选择 multi_class='auto'multi_class='ovr' 都会导致相同的结果solver='liblinear'
  2. 如果 solver='liblinear' 使用等于 1 的默认偏差项并通过 intercept_scaling 属性附加到 X(这反过来仅在 fit_intercept=True 时有用),如您所见{3}}。您将在拟合后获得由 (n_classes,) 返回的拟合偏差(维度 intercept_)(如果 fit_intercept=False 为零值)。拟合系数由 coef_(维度 (n_classes, n_features) 而不是 (n_classes, n_features + 1) - 拆分完成 here)返回。

这里有一个例子,考虑 Iris 数据集(有 3 个类和 4 个特征):

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
X, y = load_iris(return_X_y=True)

clf = LogisticRegression(random_state=0, fit_intercept=True, multi_class='ovr', solver='liblinear')
clf.fit(X, y)
clf.intercept_, clf.coef_
################################
(array([ 0.26421853,  1.09392467, -1.21470917]),
 array([[ 0.41021713,  1.46416217, -2.26003266, -1.02103509],
        [ 0.4275087 , -1.61211605,  0.5758173 , -1.40617325],
        [-1.70751526, -1.53427768,  2.47096755,  2.55537041]]))
相关问题