更改scikitplot中的默认线宽

时间:2019-09-10 06:26:50

标签: python matplotlib

我正在使用scikitplot库来处理某些特定的图,我正在使用示例中的代码:

import scikitplot as skplt
lr = LogisticRegression()
lr = lr.fit(X_train, y_train)
y_probas = lr.predict_proba(X_test)
skplt.metrics.plot_lift_curve(y_test, y_probas)
plt.show()

将生成以下图:

enter image description here

此图的我的问题是线条太宽,并且似乎不是改变类中参数的参数:

scikitplot.metrics.plot_lift_curve(y_true, y_probas, title='Lift Curve', ax=None, figsize=None, title_fontsize='large', text_fontsize='medium')

我尝试过:

import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 1.0

但是结果没有改变

1 个答案:

答案 0 :(得分:1)

线宽看起来是hardcoded in plot_lift_curve。在这种情况下,更改rcParams不会产生任何效果,因为它仅提供默认值,并被硬编码的值取代;

一种可能的解决方案是在事实发生后更改Line2D对象的属性:

lr = LogisticRegression()
lr = lr.fit(X_train, y_train)
y_probas = lr.predict_proba(X_test)
ax = skplt.metrics.plot_lift_curve(y_test, y_probas)
for l in ax.lines:
    l.set_lw(0.5)
plt.show()

如果您只想更改某些行,则可以使用例如ax.lines[0].set_lw(1.0)