一个简单的熊猫图会在图例上带有一个圆圈标记,从而产生预期的输出:
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
matplotlib.pyplot.show()
请注意,“ CrudeRate”图例项是带有正确圆圈标记的直线。
但是,如果我为Holt线性指数平滑函数添加一些其他图,则图例会丢失圆形标记:
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
fit.fittedvalues.plot(color=color, style="--")
title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
forecast = fit.forecast(predict).rename("${}$".format(title))
forecast.plot(color=color, legend=True, style="--")
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.show()
请注意,“ CrudeRate”图例项目只是一条没有圆圈标记的直线。
是什么原因导致第二种情况下的图例丢失了主图的圆圈标记?
答案 0 :(得分:3)
在matplotlib.pyplot.legend()
之前使用matplotlib.pyplot.show()
将解决您的问题。
由于您要绘制3张图,并且据我所知,图例中只需要2个标签,因此我们将label='_nolegend_'
传递给fit.fittedvalues.plot()
。如果不这样做,我们将在图例中具有值为None
的第三个标签。
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
fit.fittedvalues.plot(color=color, style="--", label='_nolegend_')
title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
forecast = fit.forecast(predict).rename("${}$".format(title))
forecast.plot(color=color, legend=True, style="--")
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.legend()
matplotlib.pyplot.show()
顺便提一下,为了使您更轻松地编写代码,将matplotlib.pyplot
导入为import matplotlib.pyplot as plt
后是一个好习惯。