A common task I have is plotting time series data with gray bars that denote NBER recessions. For instance, recessionplot() in MATLAB does exactly that, but I am not aware of similar functionality in Python, so I wrote the following function to automate the process:
```python
import matplotlib.pyplot as plt
import pandas as pd


def add_nber_shade(ax: plt.Axes, nber_df: pd.DataFrame, alpha: float = 0.2):
    """
    Adds NBER recession shades to a single plt.Axes (typically an "ax").

    Args:
        ax (plt.Axes): The ax you want to change, with data already plotted
        nber_df (pd.DataFrame): the pandas DataFrame with a "start" and an "end" column
        alpha (float): transparency

    Returns:
        plt.Axes: returns the same axes but with shades
    """
    # Earliest year on the first plotted line
    min_year = pd.to_datetime(min(ax.lines[0].get_xdata())).year
    # Keep only recessions that start in or after that year
    nber_to_keep = nber_df[pd.to_datetime(nber_df["start"]).dt.year >= min_year]
    # Draw one gray vertical band per recession
    for start, end in zip(nber_to_keep["start"], nber_to_keep["end"]):
        ax.axvspan(start, end, color="gray", alpha=alpha)
    return ax
```
Here, `nber_df` looks like the following (shown as a dictionary; in the examples below it is stored in a variable called `nber`):
```python
{'start': {0: '1857-07-01',
  1: '1860-11-01',
  2: '1865-05-01',
  3: '1869-07-01',
  4: '1873-11-01',
  5: '1882-04-01',
  6: '1887-04-01',
  7: '1890-08-01',
  8: '1893-02-01',
  9: '1896-01-01',
  10: '1899-07-01',
  11: '1902-10-01',
  12: '1907-06-01',
  13: '1910-02-01',
  14: '1913-02-01',
  15: '1918-09-01',
  16: '1920-02-01',
  17: '1923-06-01',
  18: '1926-11-01',
  19: '1929-09-01',
  20: '1937-06-01',
  21: '1945-03-01',
  22: '1948-12-01',
  23: '1953-08-01',
  24: '1957-09-01',
  25: '1960-05-01',
  26: '1970-01-01',
  27: '1973-12-01',
  28: '1980-02-01',
  29: '1981-08-01',
  30: '1990-08-01',
  31: '2001-04-01',
  32: '2008-01-01',
  33: '2020-03-01'},
 'end': {0: '1859-01-01',
  1: '1861-07-01',
  2: '1868-01-01',
  3: '1871-01-01',
  4: '1879-04-01',
  5: '1885-06-01',
  6: '1888-05-01',
  7: '1891-06-01',
  8: '1894-07-01',
  9: '1897-07-01',
  10: '1901-01-01',
  11: '1904-09-01',
  12: '1908-07-01',
  13: '1912-02-01',
  14: '1915-01-01',
  15: '1919-04-01',
  16: '1921-08-01',
  17: '1924-08-01',
  18: '1927-12-01',
  19: '1933-04-01',
  20: '1938-07-01',
  21: '1945-11-01',
  22: '1949-11-01',
  23: '1954-06-01',
  24: '1958-05-01',
  25: '1961-03-01',
  26: '1970-12-01',
  27: '1975-04-01',
  28: '1980-08-01',
  29: '1982-12-01',
  30: '1991-04-01',
  31: '2001-12-01',
  32: '2009-07-01',
  33: '2020-05-01'}}
```
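A minimal sketch of turning that dictionary into the `nber` DataFrame used in the examples. It assumes the dictionary is assigned to a variable (the name `nber_dict` is hypothetical) and, for brevity, keeps only its last four rows; parsing the date strings up front is an optional extra step, not something the function requires.

```python
import pandas as pd

# Last four rows of the dictionary shown above (`nber_dict` is a hypothetical name);
# paste the full dictionary to get all 34 recessions.
nber_dict = {
    "start": {30: "1990-08-01", 31: "2001-04-01", 32: "2008-01-01", 33: "2020-03-01"},
    "end": {30: "1991-04-01", 31: "2001-12-01", 32: "2009-07-01", 33: "2020-05-01"},
}

# The outer keys become columns, the inner keys become the row index.
nber = pd.DataFrame(nber_dict)

# Parse the ISO date strings once, so later code compares and plots real dates.
nber["start"] = pd.to_datetime(nber["start"])
nber["end"] = pd.to_datetime(nber["end"])

print(nber.dtypes)
```

The function calls `pd.to_datetime` on the "start" column itself, so it works with either string or parsed dates.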
The function is very simple: it takes the earliest date on the first plotted line, keeps only the recessions that start in or after that year, and draws a gray band for each of them. There are two common ways to make the plot: with matplotlib directly it works as intended, but with pandas' own plotting method it does not.
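The function as written only bounds the recessions from below, so spans after the last plotted date are still drawn, and a recession that starts before the first plotted year but is still under way inside the plot is dropped. A sketch of filtering by overlap with the whole plotted window instead, using a hypothetical helper `recessions_in_range` (not part of pandas or of the original function) and a few rows of the table above:

```python
import pandas as pd


def recessions_in_range(nber_df: pd.DataFrame, xmin, xmax) -> pd.DataFrame:
    """Hypothetical helper: keep recessions that overlap the window [xmin, xmax].

    A recession overlaps the window if it starts before the window ends
    and ends after the window starts.
    """
    start = pd.to_datetime(nber_df["start"])
    end = pd.to_datetime(nber_df["end"])
    xmin, xmax = pd.Timestamp(xmin), pd.Timestamp(xmax)
    return nber_df[(start <= xmax) & (end >= xmin)]


# A few rows of the NBER table from the question, with their original keys.
nber = pd.DataFrame({
    "start": {26: "1970-01-01", 27: "1973-12-01", 30: "1990-08-01", 32: "2008-01-01", 33: "2020-03-01"},
    "end": {26: "1970-12-01", 27: "1975-04-01", 30: "1991-04-01", 32: "2009-07-01", 33: "2020-05-01"},
})

window = recessions_in_range(nber, "1970-06-01", "2009-01-01")
print(window.index.tolist())  # [26, 27, 30, 32]

# The original rule (start year >= first plotted year) also keeps row 33,
# even though it starts after the window ends.
original = nber[pd.to_datetime(nber["start"]).dt.year >= 1970]
print(original.index.tolist())  # [26, 27, 30, 32, 33]
```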
The way it works (plotting with matplotlib directly):
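```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# nber = pd.DataFrame(<the dictionary above>)
df = pd.DataFrame(np.random.randn(3000, 2), columns=list('AB'), index=pd.date_range(start='1970-01-01', periods=3000, freq='W'))
plt.figure()
plt.plot(df.index, df['A'], lw=0.2)
add_nber_shade(plt.gca(), nber)
plt.show()
```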
The way it does not work (plotting directly with pandas):
```python
plt.figure()
df.plot(y=["A"], lw=0.2, ax=plt.gca(), legend=False)
add_nber_shade(plt.gca(), nber)
plt.show()
```
It raises the following error:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[106], line 3
1 plt.figure()
2 df.plot(y=["A"], lw = 0.2, ax = plt.gca(), legend=None)
----> 3 add_nber_shade(plt.gca(), nber)
4 plt.show()
File ~/Dropbox/Projects/SpanVol/src/spanvol/utilities.py:20, in add_nber_shade(ax, nber_df, alpha)
8 def add_nber_shade(ax: plt.Axes, nber_df: pd.DataFrame, alpha: float=0.2):
9 """
10 Adds NBER recession shades to a singe plt.axes (tipically an "ax").
11
(...)
18 plt.Axes: returns the same axes but with shades
19 """
---> 20 min_year = pd.to_datetime(min(ax.lines[0].get_xdata())).year
21 nber_to_keep = nber_df[pd.to_datetime(nber_df["start"]).dt.year >= min_year]
23 for start, end in zip(nber_to_keep["start"], nber_to_keep["end"]):
File ~/miniconda3/envs/volatility/lib/python3.11/site-packages/pandas/core/tools/datetimes.py:1146, in to_datetime(arg, errors, dayfirst, yearfirst, utc, format, exact, unit, infer_datetime_format, origin, cache)
1144 result = convert_listlike(argc, format)
1145 else:
-> 1146 result = convert_listlike(np.array([arg]), format)[0]
1147 if isinstance(arg, bool) and isinstance(result, np.bool_):
...
File tslib.pyx:552, in pandas._libs.tslib.array_to_datetime()
File tslib.pyx:541, in pandas._libs.tslib.array_to_datetime()
TypeError: <class 'pandas._libs.tslibs.period.Period'> is not convertible to datetime, at position 0
```
This is because pandas transforms the index under the hood: when the DatetimeIndex has a regular frequency (weekly here), pandas' time-series plotting converts it to periods, so the line's x-data are pandas Period objects, which pd.to_datetime cannot convert. Is there a simple way either to fix the function or to prevent pandas from doing this? Thanks a lot!
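A quick way to confirm this is to compare the type of the first x-value stored on the line in each case. The sketch assumes the off-screen Agg backend and the same kind of random data as in the question; `first_x_as_timestamp` is a hypothetical helper (not part of pandas or matplotlib) showing one way the year lookup in `add_nber_shade` could accept both kinds of x-data, by converting a Period to the Timestamp at its start.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; no window needed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(
    np.random.randn(3000, 2),
    columns=list("AB"),
    index=pd.date_range(start="1970-01-01", periods=3000, freq="W"),
)

# Case 1: plain matplotlib keeps datetime-like x-data.
fig1, ax1 = plt.subplots()
ax1.plot(df.index, df["A"], lw=0.2)
x_mpl = ax1.lines[0].get_xdata()[0]

# Case 2: pandas' time-series plotting stores Period objects as x-data.
fig2, ax2 = plt.subplots()
df.plot(y=["A"], lw=0.2, ax=ax2, legend=False)
x_pd = ax2.lines[0].get_xdata()[0]

print(type(x_mpl), type(x_pd))


def first_x_as_timestamp(ax: plt.Axes) -> pd.Timestamp:
    """Hypothetical helper: earliest x-value of the first line as a Timestamp."""
    x0 = min(ax.lines[0].get_xdata())
    if isinstance(x0, pd.Period):
        return x0.to_timestamp()  # start of the period
    return pd.Timestamp(x0)


print(first_x_as_timestamp(ax1), first_x_as_timestamp(ax2))
```

Because `to_timestamp()` returns the start of the period, the weekly case can report a date a few days before the first observation (the week ending on 4 January 1970 starts in late December 1969); for a year filter on this table that makes no difference.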
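A simple fix could be to pass `x_compat=True` when plotting from pandas. This makes pandas skip its period-based time-series plotting, so the line's x-data stay plain datetimes that the function can handle.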
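A minimal sketch of that fix, reusing the question's function unchanged (copied here so the block runs on its own) and a few rows of the NBER table. The Agg backend is an assumption for running without a display, and the dates are parsed up front as an extra precaution.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen rendering for this sketch

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def add_nber_shade(ax: plt.Axes, nber_df: pd.DataFrame, alpha: float = 0.2):
    """The question's function, unchanged apart from formatting."""
    min_year = pd.to_datetime(min(ax.lines[0].get_xdata())).year
    nber_to_keep = nber_df[pd.to_datetime(nber_df["start"]).dt.year >= min_year]
    for start, end in zip(nber_to_keep["start"], nber_to_keep["end"]):
        ax.axvspan(start, end, color="gray", alpha=alpha)
    return ax


# A few rows of the NBER table from the question, parsed to datetimes.
nber = pd.DataFrame({
    "start": ["1980-02-01", "1981-08-01", "1990-08-01", "2001-04-01"],
    "end": ["1980-08-01", "1982-12-01", "1991-04-01", "2001-12-01"],
}).apply(pd.to_datetime)

df = pd.DataFrame(
    np.random.randn(3000, 2),
    columns=list("AB"),
    index=pd.date_range(start="1970-01-01", periods=3000, freq="W"),
)

fig, ax = plt.subplots()
# x_compat=True: plot the DatetimeIndex as datetimes instead of converting it to periods.
df.plot(y=["A"], lw=0.2, ax=ax, legend=False, x_compat=True)
add_nber_shade(ax, nber)
# plt.show() in an interactive session
```

One side effect: with `x_compat=True` the tick labels are chosen by matplotlib rather than by pandas' automatic tick resolution adjustment, so the x-axis may look slightly different from the default pandas plot.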