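I am using code from the following resource: http://www.ds100.org/sp20/resources/assets/lectures/lec25/lec25-decision-trees.html
to build a random forest from several decision trees. The goal is to predict how likely a store on an online platform is to churn, based on its revenue and view count.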
UID,Viewer,Revenue,Churn
100,1000,5000,0
111,200,200,1
123,8000,12500,0
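I formatted the data as a CSV file and used the code from the link above to generate these decision trees/plots:
The problem is that the random forest produces this result:
This is the code I used, adapted from the code in the link: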
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib.colors import ListedColormap

def plot_decision_tree(decision_tree_model, data=None, disable_axes=False):
    # Use the first three colors of the seaborn palette for the class regions.
    sns_cmap = ListedColormap(np.array(sns.color_palette())[0:3, :])
    # Grid of points at which the model is evaluated.
    xx, yy = np.meshgrid(np.arange(4, 8, 0.02),
                         np.arange(1.9, 4.5, 0.02))
    Z_string = decision_tree_model.predict(np.c_[xx.ravel(), yy.ravel()])
    # Map the predicted labels to integer codes so they can be contoured.
    categories, Z_int = np.unique(Z_string, return_inverse=True)
    Z_int = Z_int.reshape(xx.shape)
    cs = plt.contourf(xx, yy, Z_int, cmap=sns_cmap)
    if data is not None:
        sns.scatterplot(data=data, x="Viewer", y="Revenue", hue="Churn", legend=False)
    if disable_axes:
        plt.axis("off")

gs1 = gridspec.GridSpec(3, 3)
gs1.update(wspace=0.025, hspace=0.025)  # set the spacing between axes.

# ten_decision_tree_models is the list of fitted trees built earlier,
# as in the linked notebook.
for i in range(0, 9):
    plt.subplot(gs1[i])
    plot_decision_tree(ten_decision_tree_models[i], None, True)
plt.savefig("random_forest_model_9_examples.png", dpi=300, bbox_inches="tight")
Does anyone know why this happens?
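
One thing that may be worth checking: the grid in `plot_decision_tree` is hard-coded to 4 to 8 on the x-axis and 1.9 to 4.5 on the y-axis, while the Viewer and Revenue values in the sample rows run from the hundreds into the thousands. If that mismatch is the cause, every grid point would fall in the same corner of the feature space and each tree would color the whole panel with one class. Below is a minimal sketch of deriving the grid from the data instead. The helper name `make_grid` and its `points` and `pad` parameters are made up for illustration and are not part of the linked notebook.

```python
import numpy as np

def make_grid(data, x_col="Viewer", y_col="Revenue", points=200, pad=0.05):
    """Build a meshgrid covering the observed range of two feature columns.

    Hypothetical helper: `data` is anything indexable by column name,
    such as a pandas DataFrame or a dict of lists.
    """
    x = np.asarray(data[x_col], dtype=float)
    y = np.asarray(data[y_col], dtype=float)
    # Pad each axis by a fraction of its span so edge points are not clipped.
    x_pad = (x.max() - x.min()) * pad
    y_pad = (y.max() - y.min()) * pad
    # linspace fixes the number of grid points; a fixed 0.02 step over a
    # range of thousands would create an enormous grid.
    xs = np.linspace(x.min() - x_pad, x.max() + x_pad, points)
    ys = np.linspace(y.min() - y_pad, y.max() + y_pad, points)
    return np.meshgrid(xs, ys)

# The three sample rows from the question, as a plain dict of columns.
sample = {"Viewer": [1000, 200, 8000], "Revenue": [5000, 200, 12500]}
xx, yy = make_grid(sample)
```

The resulting `xx` and `yy` could stand in for the two `np.arange` calls, provided the training data is passed to `plot_decision_tree` (the loop currently passes `None`) or the grid is computed once outside the function.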