为什么在进行多个绘图时,图例图例会丢失标记?

时间:2019-02-20 16:40:44

标签: python pandas matplotlib statsmodels

一个简单的熊猫图会在图例上带有一个圆圈标记,从而产生预期的输出:

import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api

cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""

df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
matplotlib.pyplot.show()

Simple pandas plot

请注意,“ CrudeRate”图例项是带有正确圆圈标记的直线。

但是,如果我为Holt线性指数平滑函数添加一些其他图,则图例会丢失圆形标记:

import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api

cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""

def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
  fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
  fit.fittedvalues.plot(color=color, style="--")
  title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
  forecast = fit.forecast(predict).rename("${}$".format(title))
  forecast.plot(color=color, legend=True, style="--")

df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.show()

Legend missing marker

请注意,“ CrudeRate”图例项目只是一条没有圆圈标记的直线。

是什么原因导致第二种情况下的图例丢失了主图的圆圈标记?

1 个答案:

答案 0 :(得分:3)

matplotlib.pyplot.legend()之前使用matplotlib.pyplot.show()将解决您的问题。

由于您要绘制3张图,并且据我所知,图例中只需要2个标签,因此我们将label='_nolegend_'传递给fit.fittedvalues.plot()。如果不这样做,我们将在图例中具有值为None的第三个标签。

import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api

cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""

def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
  fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
  fit.fittedvalues.plot(color=color, style="--", label='_nolegend_')
  title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
  forecast = fit.forecast(predict).rename("${}$".format(title))
  forecast.plot(color=color, legend=True, style="--")

df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.legend()
matplotlib.pyplot.show()

enter image description here

顺便提一下,为了使您更轻松地编写代码,将matplotlib.pyplot导入为import matplotlib.pyplot as plt后是一个好习惯。