import shap, pandas as pd, numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the iris dataset into a feature DataFrame (so column names travel with
# the data into SHAP and the plots below), split it, and fit an MLP.
iris = datasets.load_iris()
X = iris.data
y = iris.target

data = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                    columns=iris['feature_names'] + ['target'])
label = data['target']
data.drop('target', axis=1, inplace=True)

# Use a fixed seed: drawing it from np.random.randint made every run use a
# different split, so the model and its SHAP values were not reproducible.
RANDOM_STATE = 42
X_train, X_test, y_train, y_test = train_test_split(
    data, label, random_state=RANDOM_STATE, test_size=0.3)

# The MLP's weight initialization is random as well; seed it for the same reason.
mlp = MLPClassifier(max_iter=150, random_state=RANDOM_STATE).fit(X_train, y_train)
# A bare expression is only displayed in a notebook; print so scripts show it too.
print('Test accuracy:', mlp.score(X_test, y_test))

# Explain the MLP's class probabilities with KernelSHAP. The background set
# is summarized into 5 k-means centroids of the training data instead of
# passing every training row.
background = shap.kmeans(X_train, 5)
explainer = shap.KernelExplainer(mlp.predict_proba, background)
shap_values = explainer.shap_values(X_test)

# First plot: shap's built-in bar summary for the second predict_proba
# output (class 1).
shap.summary_plot(shap_values[1], feature_names=X_test.columns, plot_type='bar')

# Second plot: a matplotlib bar chart of the top_n most important features
# for class 1, ranked by mean absolute SHAP value.
import matplotlib.pyplot as plt

plt.rcdefaults()

top_n = 5  # number of features to show; useful when there are thousands

# shap_values[1] holds one row of SHAP values per test sample, while
# plt.bar needs exactly one height per bar. Passing the whole matrix
# caused "TypeError: only size-1 arrays can be converted to Python scalars".
# Collapsing the sample axis with mean(|SHAP|) gives one global importance
# per feature.
if isinstance(shap_values, list):
    class_shap = np.asarray(shap_values[1])
else:
    # NOTE(review): newer shap releases reportedly return a single
    # (n_samples, n_features, n_classes) array instead of a per-class list;
    # verify against the installed version.
    class_shap = np.asarray(shap_values)[:, :, 1]

importance = pd.Series(np.abs(class_shap).mean(axis=0), index=X_test.columns)
top_features = importance.nlargest(top_n)  # sorted, largest first

y_pos = np.arange(len(top_features))
plt.bar(y_pos, top_features.values, align='center', alpha=0.5)
plt.xticks(y_pos, top_features.index, rotation=45, ha='right')
plt.ylabel('SHAP Importance')
plt.title('MLP Feature Importances')
plt.tight_layout()
plt.show()


按照 this guide 使用 shap.summary_plot，我得到了如下图表：

（图片缺失：shap.summary_plot 生成的特征重要性条形图）

在我的实际数据集中，大约有 10,000 个特征。我想从中选出最重要的 n 个特征，并按照 this guide 用 matplotlib 绘制它们。但是，我得到了一个错误和一张空白图表：

（图片缺失：报错后得到的空白 matplotlib 图表）


TypeError                                 Traceback (most recent call last)
<ipython-input-42-ba03083152ca> in <module>
      1 import matplotlib.pyplot as plt; plt.rcdefaults()
      2 y_pos = np.arange(len(X_test.columns))
----> 3 plt.bar(y_pos, shap_values[1], align='center', alpha=0.5)
      4 plt.xticks(y_pos, X_test.columns)
      5 plt.ylabel('SHAP Importance')

c:\python367-64\lib\site-packages\matplotlib\pyplot.py in bar(x, height, width, bottom, align, data, **kwargs)
   2407     return gca().bar(
   2408         x, height, width=width, bottom=bottom, align=align,
-> 2409         **({"data": data} if data is not None else {}), **kwargs)

c:\python367-64\lib\site-packages\matplotlib\__init__.py in inner(ax, data, *args, **kwargs)
   1563     def inner(ax, *args, data=None, **kwargs):
   1564         if data is None:
-> 1565             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1567         bound = new_sig.bind(ax, *args, **kwargs)

c:\python367-64\lib\site-packages\matplotlib\axes\_axes.py in bar(self, x, height, width, bottom, align, **kwargs)
   2393                 edgecolor=e,
   2394                 linewidth=lw,
-> 2395                 label='_nolegend_',
   2396                 )
   2397             r.update(kwargs)

c:\python367-64\lib\site-packages\matplotlib\patches.py in __init__(self, xy, width, height, angle, **kwargs)
    725         """
--> 727         Patch.__init__(self, **kwargs)
    729         self._x0 = xy[0]

c:\python367-64\lib\site-packages\matplotlib\patches.py in __init__(self, edgecolor, facecolor, color, linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle, **kwargs)
     87         self.set_fill(fill)
     88         self.set_linestyle(linestyle)
---> 89         self.set_linewidth(linewidth)
     90         self.set_antialiased(antialiased)
     91         self.set_hatch(hatch)

c:\python367-64\lib\site-packages\matplotlib\patches.py in set_linewidth(self, w)
    393                 w = mpl.rcParams['axes.linewidth']
--> 395         self._linewidth = float(w)
    396         # scale the dash pattern by the linewidth
    397         offset, ls = self._us_dashes

TypeError: only size-1 arrays can be converted to Python scalars


# Feature importances for class 1: mean absolute SHAP value per feature,
# sorted from most to least important. The absolute value matters;
# otherwise positive and negative contributions cancel out in the mean.
fi = pd.Series(np.abs(shap_values[1]).mean(0),
               index=X_test.columns).sort_values(ascending=False)

# Extract the 5 top values and plot them (pandas draws via matplotlib).
ax = fi.head(5).plot(kind='bar', alpha=0.5)
ax.set_ylabel('SHAP Importance')
ax.set_title('MLP Feature Importances')


enter image description here