GradientBoostingClassifier的apply函数混淆了

时间:2018-05-01 05:12:23

标签: python scikit-learn

对于应用功能,您可以参考here

我的混淆更多来自this sample,我在下面的代码片段中添加了一些打印输出更多的调试信息,

price

输出如下所示,并且混淆了grd = GradientBoostingClassifier(n_estimators=n_estimator) grd_enc = OneHotEncoder() grd_lm = LogisticRegression() grd.fit(X_train, y_train) test_var = grd.apply(X_train)[:, :, 0] print "test_var.shape", test_var.shape print "test_var", test_var grd_enc.fit(grd.apply(X_train)[:, :, 0]) grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr) 6.3.这些数字的含义?以及它们与最终分类结果的关系如何?

10.

1 个答案:

答案 0 :(得分:4)

要了解渐变增强,首先需要了解各个树。我将展示一个小例子。

以下是设置:在Iris数据集上训练的小型GB模型,用于预测花是否属于2级。

# import the most common dataset
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
X, y = load_iris(return_X_y=True)
# there are 150 observations and 4 features
print(X.shape) # (150, 4)
# let's build a small model = 5 trees with depth no more than 2
model = GradientBoostingClassifier(n_estimators=5, max_depth=2, learning_rate=1.0)
model.fit(X, y==2) # predict 2nd class vs rest, for simplicity
# we can access individual trees
trees = model.estimators_.ravel()
print(len(trees)) # 5
# there are 150 observations, each is encoded by 5 trees, each tree has 1 output
applied = model.apply(X) 
print(applied.shape) # (150, 5, 1)
print(applied[0].T) # [[2. 2. 2. 5. 2.]] - a single row of the apply() result
print(X[0]) # [5.1 3.5 1.4 0.2] - the pbservation corresponding to that row
print(trees[0].apply(X[[0]])) # [2] - 2 is the result of application the 0'th tree to the sample
print(trees[3].apply(X[[0]])) # [5] - 5 is the result of application the 3'th tree to the sample

您可以看到[2. 2. 2. 5. 2.]生成的序列model.apply()中的每个数字对应于单个树的输出。但这些数字意味着什么?

我们可以通过目视检查轻松分析决策树。这是一个绘制一个

的函数
# a function to draw a tree. You need pydotplus and graphviz installed 
# sudo apt-get install graphviz
# pip install pydotplus

from sklearn.externals.six import StringIO  
from IPython.display import Image  
from sklearn.tree import export_graphviz
import pydotplus
def plot_tree(clf):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, node_ids=True,
                    filled=True, rounded=True, 
                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
    return Image(graph.create_png())

# now we can plot the first tree
plot_tree(trees[0])

enter image description here

您可以看到每个节点都有一个数字(从0到6)。如果我们将单个示例推送到此树中,它将首先转到节点#1(因为功能x3具有值0.2 < 1.75),然后转到节点#2(因为功能{{1} }具有值x2

以同样的方式,我们可以分析产生输出1.4 < 4.95的树3:

5

enter image description here

这里我们的观察首先是节点#4,然后是节点#5,因为plot_tree(trees[3]) x1=3.5>2.25。因此,它最终得到5号。

就这么简单! x2=1.4<4.85生成的每个数字都是样本结束的相应树节点的序号。

这些数字与最终分类结果的关系是通过相应树中叶子的apply()。在二进制分类的情况下,所有叶子中的value只是加起来,如果是正数,那么“正”&#39;获胜,否则“否定”#39;类。在多类分类的情况下,每个类的值相加,总值最大的类获胜。

在我们的例子中,第一棵树(其节点#2)给出值-1.454,​​其他树也给出一些值,它们的总和是-4.84。这是否定的,因此,我们的例子不属于第2类。

value