可视化RandomForestRgressor中的所有树

时间:2019-04-27 22:57:56

标签: scikit-learn

我已经在数据集上训练了RandomForestRegressor模型,并在训练集上对其进行了训练,因为训练精度(Pearsons r)很好(0.9),所以我继续在测试集和测试准确度(Pearsons r)也不错,为0.89。这已应用于土木工程问题。

Rand_Fore_Cla = RandomForestRegressor(n_estimators = 10,random_state = 42,min_samples_leaf = 6)

我现在想以png格式可视化所有十棵树,以便对于现场工程师进行任何新的预测都不必使用计算机,他可以检查这10棵估计器树并为给定的树进行预测X_new。

我可以通过以下代码提取全部10棵树(应用10次)

第一棵树:

export_graphviz(Rand_Fore_Cla.estimators_ [0],out_file =“ tree_Rand_Forest_tree1.dot”,feature_names = x_train.columns,filled = True,rounded = True)

os.system('dot -Tpng tree_Rand_Forest.dot -o tree_Rand_Forest.png')

第二棵树:

export_graphviz(Rand_Fore_Cla.estimators_ [1],out_file =“ tree_Rand_Forest_tree1.dot”,feature_names = x_train.columns,filled = True,rounded = True)

os.system('dot -Tpng tree_Rand_Forest.dot -o tree_Rand_Forest.png')

所有十棵树都这样……

现在,对于给定的X_new,我遍历了所有10棵树,记下了10个y_pred值,然后计算了平均值,但该平均值与使用预测方法获得的值不匹配

y_pred = Rand_Fore_Cla.predict(X_new)

0 个答案:

没有答案