分类决策树中的学习曲线是什么意思?

时间:2019-02-10 21:49:36

标签: decision-tree grid-search

我在分析中使用了分类决策树。首先,我将整个数据分为培训和测试-60%:40%。然后,在训练集上使用GridSearch以获得得分最高的模型(max_depth = 7)。然后,我在交叉验证集和训练集上绘制了学习曲线。这是我得到的图。似乎两行重叠。那它告诉我什么呢?我的模型中没有过度拟合吗?总的来说,为什么我们需要分析中的学习曲线?

Link to my learning curve image

非常感谢!

2 个答案:

答案 0 :(得分:4)

学习曲线显示了针对不同训练样本数量的估计器的验证和训练分数。它是一种工具,可用于了解我们从添加更多训练数据中获益多少,以及估计器是否更容易受到方差误差或偏差误差的影响。

机器学习曲线可用于多种用途,包括比较不同算法、在设计期间选择模型参数、调整优化以提高收敛性以及确定用于训练的数据量。

您没有很好地利用学习曲线工具,因为您从非常大的训练规模开始,它不允许您很好地看到模型的行为。

下面是一个示例,其中显示了一个图,其中您以较小的训练规模开始分析,而另一个以非常大的训练规模开始分析(您的案例)。为此,您只需改变 sklearn.model_selection.learning_curve 的 train_sizes 参数。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from get_csv_data import HandleData
from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit

def plot_learning_curve(estimator, X, y, ax=None, ylim=(0.5, 1.01), cv=None, n_jobs=4, train_sizes=np.linspace(.1, 1.0, 5)):

    train_sizes, train_scores, test_scores = \
        learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
              
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    # Plot learning curve
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel("Training examples")
    ax.set_ylabel("Score")
    ax.plot(train_sizes, train_scores_mean, 'o-', color="r", label="Training score")
    ax.plot(train_sizes, test_scores_mean, 'o-', color="g", label="Cross-validation score")
    ax.legend(loc="best")

    return plt

fig, (ax1, ax2) = plt.subplots(1, 2)

data = HandleData(oneHotFlag=False)
#get the data
X, y = data.get_synthatic_data()

cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
estimator = SVC()
plot_learning_curve(estimator, X, y, ax = ax1, cv=cv, train_sizes=np.linspace(.1, 1.0, 5))
plot_learning_curve(estimator, X, y, ax = ax2, cv=cv, train_sizes=np.linspace(.5, 1.0, 5))

plt.show()

output:

答案 1 :(得分:0)

您的图表显示了准确度与训练示例数量的关系。训练示例的数量越多,对模型进行训练的训练数据点的数量就越大。

训练准确性是在训练的模型上对训练的数据进行测试时的准确性得分。本质上,它已经对已经看到的数据进行了测试

在交叉验证中,数据被随机分为训练和测试集。在训练集上训练模型,并在测试集上进行测试。准确性分数反映了对测试集的预测程度。

线条是重合的,因为该模型可能受过良好训练:与预测训练过的事物一样,它在预测以前从未见过的事物方面同样出色。