在2x2网格中绘制形状图

时间:2020-10-25 06:35:55

标签: python matplotlib subplot shap

我试图在2x2子图中绘制4个Shap依赖图,但无法正常工作。我尝试了以下方法:

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
ax1 = plt.subplot(221)
shap.dependence_plot('age', shap_values[1], X_train, ax=ax1)
ax2 = plt.subplot(222)
shap.dependence_plot('income', shap_values[1], X_train, ax=ax2)
ax3 = plt.subplot(223)
shap.dependence_plot('score', shap_values[1], X_train, ax=ax2)

这:

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
shap.dependence_plot('age', shap_values[1], X_train)
plt.subplot(1,2,2)
shap.dependence_plot('income', shap_values[1], X_train)
plt.subplot(2,2,3)
shap.dependence_plot('score', shap_values[1], X_train)
plt.tight_layout()
plt.show()

它一直将它们绘制在不同的行上,而不是2x2格式。

1 个答案:

答案 0 :(得分:1)

您需要将参数show=False传递给依赖图。

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
shap.dependence_plot('age', shap_values[1], X_train, ax=axes[0, 0], show=False)
shap.dependence_plot('income', shap_values[1], X_train, ax=axes[0, 1], show=False)
shap.dependence_plot('score', shap_values[1], X_train, ax=axes[1, 0], show=False)
plt.show()

This notebook中,您可以阅读有关此参数的以下内容:

''通过传递show = False可以防止shap.dependence_plot调用 matplotlib show()函数,因此您可以继续自定义图 最终打电话给自己秀''