Python + SciKit - >手动和cross_val_score预测的不同结果

时间:2015-08-22 19:36:05

标签: python scikit-learn kaggle

我正在从Kaggle那里学习泰坦尼克号的学习任务。

如果我执行手动分离数据的线性回归或使用cross_val_score,我的预测准确度会有所不同。逻辑回归也是如此。

示例

- 线性回归。

手册

Algorithm = LinearRegression()
kf = KFold(dataset.shape[0], n_folds=3, random_state=1)
predictions = []

for train, test in kf:

    train_predictors = (dataset[Predictors].iloc[train])
    train_target = dataset['Survived'].iloc[train]
    Algorithm.fit(train_predictors, train_target)
    test_predictions = Algorithm.predict(dataset[Predictors].iloc[test])
    predictions.append(test_predictions)

predictions = np.concatenate(predictions, axis=0)
print(predictions.shape[0])
realed = list(dataset.Survived)
predictions[predictions > 0.5] = 1
predictions[predictions <= 0.5] = 0

accuracy2 = sum(predictions[predictions == dataset["Survived"]]) / len(predictions)
print("Tochnost prognoza: ", accuracy2 * 100, " %")

结果 - 78,34%

Cross_val_score

scores=cross_val_score(LinearRegression(), dataset[Predictors], dataset["Survived"], cv=3)
print(scores.mean())

结果 - 37,5%

- 逻辑回归。

这里我有26,15%的手动和78,78%的cross_val_score功能。

为什么?

1 个答案:

答案 0 :(得分:3)

对于您的代码,有几件事情看起来很不对。

  1. 您的准确度计算错误 这一行:

    accuracy2 = sum(predictions[predictions == dataset["Survived"]]) / len(predictions)
    

    不计算准确性。它做的是采取你做出正确预测时所做预测的平均值。这没有多大意义;) 这很容易修复:

    accuracy2 = sum(predictions == dataset["Survived"] / len(predictions)
    
  2. 线性回归实际上会执行回归 使用线性回归来执行分类任务不是一个好主意。在(二进制)分类中,您期望输出范围为[0; 1](概率),而线性回归通常给你一个无限的范围 由于统计学家是线性回归的忠实粉丝,他们发明了逻辑回归,这实际上是对转换目标值的线性回归。
    底线:使用逻辑回归(非线性回归)进行分类。

  3. 得分方法不是您认为的那些 cross_val_score需要scoring个参数。在这里你没有指定它(所以它是None),这意味着它将查找估算器的默认分数方法。 LinearRegression 的默认分数方法不准确。它是R ^ 2系数。这与回归有关,而不是你想要做的事情。

    所以当你这样做时:

    scores=cross_val_score(LinearRegression(), dataset[Predictors], dataset["Survived"], cv=3)
    print(scores.mean())
    

    您得到的是3倍交叉验证的平均R ^ 2系数 当您使用LogisticRegression执行此操作时,您将获得平均准确度,这是您想要的。

  4. 第1点和第2点说明了LogisticRegressioncross_val_scoreLinearRegression获得的结果。
    我还不确定第一个案例,如果我找到一个好的解释,我会更新我的帖子。我发现非常令人惊讶,因为您在计算准确性方面的错误总是低估结果。除非这不是您运行的实际代码。