使用 SHAP 来解释 DNN 模型,但我的 summary_plot 只显示了每个特征的平均影响,并不包括所有特征

时间:2021-04-18 22:44:01

标签: python keras neural-network shap

所以我正在生成一个形状汇总图,如下所示:

explainer = shap.KernelExplainer(model, X_test[:100,:])
shap_values = explainer.shap_values(X_test[:100,:])
fig = shap.summary_plot(shap_values, features=X_test[:100,:], feature_names=feature_names, show=False)
plt.savefig('test.png')

这可以正常工作并创建一个如下所示的图:

my summary plot

这看起来没问题,但有几个问题。通过阅读 shap summary_plots,我经常看到如下所示:

example summary plot how i want mine to look

如您所见 - 这看起来与我的有点不同。根据两个summary_plots底部的文本,我的看起来像是在显示每个特征的平均形状值,而我在网上看到的只是显示每个特征的每个单独的数据点——换句话说,我在网上看到的那些看起来更多颗粒状。

如何创建一个不显示每个特征的平均影响而只显示每个数据点的摘要图?我认为summary_plot() 必须有一个布尔参数,比如use_average 之类的,但找不到任何东西。

此外,正如您在我的 summary_plot 中看到的那样 - y 轴上仅包含 20 个特征。我的模型实际上有大约 100 个特征,如果可能,我想将所有特征都包含在 summary_plot 中。我认为 shap 默认显示为 20,但我希望有办法增加这个数字。

1 个答案:

答案 0 :(得分:1)

我的理解是 shap.summary_plot 仅绘制条形图,当模型有多个输出时,或者即使 SHAP 认为它有多个输出(在我的情况下是这样)。当我尝试使用 summary_plot 的 plot_type 选项将绘图强制为“点”时,它给了我一个断言错误来解释这个问题。

您可以尝试复制该错误消息:

shap.summary_plot(shap_values, x_train, plot_type='dot', show = False)

如果您遇到相同的错误,请尝试对模型中的第一个输出变量执行此操作:

shap.summary_plot(shap_values[0], x_train, show = False)

这似乎解决了我的问题。

至于尝试增加参数的数量,我相信 max_display 选项应该会有所帮助,虽然我没有尝试过超过 20(我的模型不是那么大):

shap.summary_plot(shap_values[0], x_train, max_display = 5, show = False)

希望能帮到你。祝你好运:)