为什么我的随机森林图具有相同的颜色?

时间:2020-06-10 09:20:34

标签: python pandas decision-tree

我正在使用以下资源中的代码:http://www.ds100.org/sp20/resources/assets/lectures/lec25/lec25-decision-trees.html

从某些决策树中生成随机森林,以根据其收入和浏览量预测在线平台上商店的流失可能性。

UID,Viewer,Revenue,Churn
100,1000,5000,0
111,200,200,1
123,8000,12500,0

我将数据格式化为CSV格式,并使用上面链接中的代码生成了这些决策树/图形:

Plot produced

问题是随机林生成的结果是:

Random forest

这是我使用的代码,改编自链接中的代码:

def plot_decision_tree(decision_tree_model, data = None, disable_axes = False):
    from matplotlib.colors import ListedColormap
    sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])

    xx, yy = np.meshgrid(np.arange(4, 8, 0.02),
                     np.arange(1.9, 4.5, 0.02))

    Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
    categories, Z_int = np.unique(Z_string, return_inverse=True)
    Z_int = Z_int.reshape(xx.shape)
    cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
    if data is not None:
        sns.scatterplot(data = data, x = "Viewer", y="Revenue", hue="Churn", legend=False)

    if disable_axes:
        plt.axis("off")
import matplotlib.gridspec as gridspec
gs1 = gridspec.GridSpec(3, 3)
gs1.update(wspace=0.025, hspace=0.025) # set the spacing between axes. 

for i in range(0, 9):
    plt.subplot(gs1[i]) #3, 3, i)
    plot_decision_tree(ten_decision_tree_models[i], None, True)    

plt.savefig("random_forest_model_9_examples.png", dpi = 300, bbox_inches = "tight")    

有人知道为什么会这样吗?

0 个答案:

没有答案