使用matplotlib python绘制决策树分类器的2个以上功能

时间:2018-11-22 19:27:15

标签: python numpy matplotlib

数据集

Pima Indians Dataset上,我一直在使用“决策树分类器”进行分类。但是,我有自己的结果,并且作为一个明显的阶段,我一直在寻找相同的可视化效果。

这是数据集的标题:

   TimesPregnant  GlucoseConcentration  BloodPrs  SkinThickness  Serum   BMI  \
0              6                   148        72             35      0  33.6   
1              1                    85        66             29      0  26.6   
2              8                   183        64              0      0  23.3   
3              1                    89        66             23     94  28.1   
4              0                   137        40             35    168  43.1   

   DiabetesFunct  Age  Class  
0          0.627   50      1  
1          0.351   31      0  
2          0.672   32      1  
3          0.167   21      0  
4          2.288   33      1 

要绘制两个以上的特征?

这是我使用有关网络的参考资料和教程组装的代码。显然,它不适用于2个以上的功能。您可以在这里注意到,除了最后一列之外,其他所有内容都是我的功能。

代码

# Visualising the Training set results
from matplotlib.colors import ListedColormap
X_set, y_set = X_train, y_train
X1, X2 = np.meshgrid(np.arange(start = X_set[:, 0].min() - 1, stop = X_set[:, 0].max() + 1, step = 0.01),
                     np.arange(start = X_set[:, 1].min() - 1, stop = X_set[:, 1].max() + 1, step = 0.01))
plt.contourf(X1, X2, classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
             alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
    plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
                c = ListedColormap(('red', 'green'))(i), label = j)
plt.title('Decision Tree (Train set)')
plt.xlabel('Age')
plt.ylabel('Estimated Salary')
plt.legend()
plt.show() 

您可能会注意到X1X2由网状网格组成,以便利用我正在使用的空间进行着色,但是,如果您建议的解决方案涵盖绘制的图形超过2个,则可以随意忽略尽可能在matplotlib上提供功能。

现在,我无法在这里为8个功能制作8个X,我正在寻找一种非常有效的方法来实现相同功能。

1 个答案:

答案 0 :(得分:1)

这是您的操作方式:

from itertools import product
from matplotlib import pyplot as plt
import numpy as np
import scipy.stats as sts

features = [np.linspace(0, 5),

            np.linspace(9, 14),

            np.linspace(6, 11),
            np.linspace(3, 8)]

labels = ['height',
          'weight',
          'bmi',
          'age']

n = len(features)
fig, axarr = plt.subplots(n, n, figsize=(4*n, 4*n))
fig.subplots_adjust(0, 0, 1, 1, 0, 0)

for (x,y),ax in zip(product(features, features), axarr.T.flat):
    X,Y = np.meshgrid(x, y)

    # get some fake data for demo purposes
    mnorm = sts.multivariate_normal([x.mean()**(7/10), y.mean()**(11/10)])
    Z = mnorm.pdf(np.stack([X, Y], 2))

    ax.contourf(X, Y, Z)

# label and style the plot
# ...in progress

输出:

enter image description here