Plotting the decision boundary of an H2O model in Python

Time: 2017-11-07 12:52:52

Tags: python-3.x h2o

I want to plot the decision boundary of an H2O random forest model in Python, like this:

Example of what I would like the plot to look like

All the examples I have found so far use scikit-learn.

1 answer:

Answer 0 (score: 3)

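To plot the decision boundary of an H2O model, you can use matplotlib. matplotlib cannot plot H2OFrames directly, so you first need to convert the H2O predictions to a NumPy array or a pandas DataFrame. The idea is to predict the class probability at every point of a fine grid covering the two features and draw the result as a filled contour plot. Here is an example for a binary classification problem with two features: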

import h2o
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from h2o.estimators.random_forest import H2ORandomForestEstimator

h2o.init()
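# Import the data into an H2OFrame (this example assumes the first two columns
# are numeric features and the last column is the 0/1 target)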
hf = h2o.import_file('data.csv')

# Convert the target into a factor for classification
hf[:,-1] = hf[:,-1].asfactor()

# Split the data into train/test
hf_train, hf_test = hf.split_frame(ratios=[0.75])

# columns used for the training
X_cols = hf_train.col_names[:-1]

# last column is the target
y_col = hf_train.col_names[-1]

# Random Forest classifier
rf_clf = H2ORandomForestEstimator(ntrees=10)
rf_clf.train(x=X_cols, y=y_col, training_frame=hf_train, validation_frame=hf_test)
# Predictions on the held-out test set (not needed for the plot itself)
y_pred = rf_clf.predict(test_data=hf_test[:,X_cols])

# Convert to pandas df and create a mesh
df = hf.as_data_frame()
x1_min, x1_max = df.iloc[:, 0].min() - .5, df.iloc[:, 0].max() + .5
x2_min, x2_max = df.iloc[:, 1].min() - .5, df.iloc[:, 1].max() + .5
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, 0.02), 
                       np.arange(x2_min, x2_max, 0.02))

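# Predict the mesh values with the H2O random forest and convert back to a pandas df.
# The grid frame gets the training column names so H2O can match the features.
Z = rf_clf.predict(h2o.H2OFrame(np.c_[xx1.ravel(), xx2.ravel()],
                                 column_names=X_cols)).as_data_frame()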
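# 'p1' is the predicted probability of class 1 (H2O names the probability
# columns after the class labels); reshape it back to a 2D grid
zz = Z['p1'].values.reshape(xx1.shape)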

# Plot the results
cm_scatt = ListedColormap(['b', 'r'])
fig = plt.figure(figsize=(12, 9))
# decision boundary
plt.contourf(xx1, xx2, zz, cmap='jet', alpha=.8)

# scatter plot of the full dataset
plt.scatter(df.iloc[:, 0], df.iloc[:, 1], c=df.iloc[:, 2], cmap=cm_scatt,
            edgecolors='k')
# Annotate with a model score (validation AUC for this binary classifier)
plt.text(xx1.max(), xx2.min(), 'AUC: {:.2f}'.format(rf_clf.auc(valid=True)),
         horizontalalignment='right', color='w', fontsize=18)
plt.show()

# shutdown H2O cluster
h2o.cluster().shutdown()

H2O Random Forest Decision Boundary
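The grid-and-reshape step above does not depend on H2O at all, which is why the scikit-learn examples carry over once the predictions are NumPy arrays. Below is a minimal sketch, assuming only NumPy, that factors this step into a helper. `predict_on_grid` and `toy_proba` are hypothetical names used for illustration, and `toy_proba` is a stand-in rule, not a trained model.

```python
import numpy as np


def predict_on_grid(X, predict_proba, step=0.02, pad=0.5):
    """Evaluate a classifier on a regular 2-D grid covering the data in X.

    X: array of shape (n_samples, 2) holding the two feature columns.
    predict_proba: any callable mapping an (m, 2) array to m probabilities
        of the positive class (hypothetical stand-in for the H2O call).
    Returns (xx1, xx2, zz), where zz has the same shape as xx1 and xx2.
    """
    x1_min, x1_max = X[:, 0].min() - pad, X[:, 0].max() + pad
    x2_min, x2_max = X[:, 1].min() - pad, X[:, 1].max() + pad
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, step),
                           np.arange(x2_min, x2_max, step))
    # Flatten the grid into an (m, 2) table with one row per grid point,
    # which is the tabular shape a model's predict() expects.
    grid = np.c_[xx1.ravel(), xx2.ravel()]
    probs = np.asarray(predict_proba(grid))
    # Restore the 2-D layout so contourf can draw it.
    zz = probs.reshape(xx1.shape)
    return xx1, xx2, zz


# Usage with a toy rule: class 1 whenever x1 + x2 > 0 (not a real model).
X = np.array([[-1.0, -1.0], [1.0, 1.0], [-1.0, 1.0], [1.0, -1.0]])


def toy_proba(points):
    return (points[:, 0] + points[:, 1] > 0).astype(float)


xx1, xx2, zz = predict_on_grid(X, toy_proba, step=0.5)
print(zz.shape)  # (6, 6)
```

With the H2O model, `predict_proba` would wrap the calls from the answer: build `h2o.H2OFrame(grid, column_names=X_cols)`, pass it to `rf_clf.predict(...)`, and return `.as_data_frame()['p1'].values`. To draw only the boundary line instead of the filled probabilities, `plt.contour(xx1, xx2, zz, levels=[0.5])` traces the 0.5-probability level.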