How to prevent Pandas from plotting index as Period?


A common task I have is plotting time series data and adding gray bars that denote NBER recessions. For instance, recessionplot() in MATLAB does exactly that. I am not aware of similar functionality in Python, so I wrote the following function to automate the process:

import matplotlib.pyplot as plt
import pandas as pd


def add_nber_shade(ax: plt.Axes, nber_df: pd.DataFrame, alpha: float = 0.2):
    """
    Adds NBER recession shades to a single plt.Axes (typically an "ax").

    Args:
        ax (plt.Axes): The axes to modify, with data already plotted
        nber_df (pd.DataFrame): DataFrame with a "start" and an "end" column
        alpha (float): transparency of the shaded bars

    Returns:
        plt.Axes: the same axes, with the shades added
    """
    # Year of the earliest x value on the first plotted line
    min_year = pd.to_datetime(min(ax.lines[0].get_xdata())).year
    # Keep only recessions that start in or after that year
    nber_to_keep = nber_df[pd.to_datetime(nber_df["start"]).dt.year >= min_year]

    for start, end in zip(nber_to_keep["start"], nber_to_keep["end"]):
        ax.axvspan(start, end, color="gray", alpha=alpha)

    return ax

Here, nber_df looks like the following (shown as a dictionary):

{'start': {0: '1857-07-01',
  1: '1860-11-01',
  2: '1865-05-01',
  3: '1869-07-01',
  4: '1873-11-01',
  5: '1882-04-01',
  6: '1887-04-01',
  7: '1890-08-01',
  8: '1893-02-01',
  9: '1896-01-01',
  10: '1899-07-01',
  11: '1902-10-01',
  12: '1907-06-01',
  13: '1910-02-01',
  14: '1913-02-01',
  15: '1918-09-01',
  16: '1920-02-01',
  17: '1923-06-01',
  18: '1926-11-01',
  19: '1929-09-01',
  20: '1937-06-01',
  21: '1945-03-01',
  22: '1948-12-01',
  23: '1953-08-01',
  24: '1957-09-01',
  25: '1960-05-01',
  26: '1970-01-01',
  27: '1973-12-01',
  28: '1980-02-01',
  29: '1981-08-01',
  30: '1990-08-01',
  31: '2001-04-01',
  32: '2008-01-01',
  33: '2020-03-01'},
 'end': {0: '1859-01-01',
  1: '1861-07-01',
  2: '1868-01-01',
  3: '1871-01-01',
  4: '1879-04-01',
  5: '1885-06-01',
  6: '1888-05-01',
  7: '1891-06-01',
  8: '1894-07-01',
  9: '1897-07-01',
  10: '1901-01-01',
  11: '1904-09-01',
  12: '1908-07-01',
  13: '1912-02-01',
  14: '1915-01-01',
  15: '1919-04-01',
  16: '1921-08-01',
  17: '1924-08-01',
  18: '1927-12-01',
  19: '1933-04-01',
  20: '1938-07-01',
  21: '1945-11-01',
  22: '1949-11-01',
  23: '1954-06-01',
  24: '1958-05-01',
  25: '1961-03-01',
  26: '1970-12-01',
  27: '1975-04-01',
  28: '1980-08-01',
  29: '1982-12-01',
  30: '1991-04-01',
  31: '2001-12-01',
  32: '2009-07-01',
  33: '2020-05-01'}}
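
To reproduce the examples, the dictionary above can be turned back into a DataFrame. This is a minimal sketch that uses only the last three rows for brevity. The name nber_dict is chosen here for illustration, while nber matches the variable passed to add_nber_shade in the snippets that follow.

```python
import pandas as pd

# Last three rows of the dictionary shown above (rows 31 to 33).
nber_dict = {
    "start": {31: "2001-04-01", 32: "2008-01-01", 33: "2020-03-01"},
    "end": {31: "2001-12-01", 32: "2009-07-01", 33: "2020-05-01"},
}

# Outer keys become columns, inner keys become the row index.
nber = pd.DataFrame(nber_dict)

# The dates are stored as strings, so they must be parsed before
# any year comparison.
start_years = pd.to_datetime(nber["start"]).dt.year
print(nber)
print(start_years.tolist())
```

Because the columns hold strings rather than datetimes, add_nber_shade calls pd.to_datetime on the "start" column before filtering.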

The function is very simple. It reads the earliest date plotted on the axes, keeps the recessions that start in or after that year, and draws a shaded span for each one. There are two common ways to draw the underlying plot: calling Matplotlib directly works as intended, but plotting through pandas does not.

The way that works (plotting with Matplotlib directly):

import numpy as np

df = pd.DataFrame(
    np.random.randn(3000, 2),
    columns=list("AB"),
    index=pd.date_range(start="1970-01-01", periods=3000, freq="W"),
)

plt.figure()
plt.plot(df.index, df["A"], lw=0.2)
add_nber_shade(plt.gca(), nber)  # nber: the recession DataFrame shown above
plt.show()

The way that does not work (plotting directly with pandas):

plt.figure()
df.plot(y=["A"], lw = 0.2, ax = plt.gca(), legend=None)
add_nber_shade(plt.gca(), nber)
plt.show()

It raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[106], line 3
      1 plt.figure()
      2 df.plot(y=["A"], lw = 0.2, ax = plt.gca(), legend=None)
----> 3 add_nber_shade(plt.gca(), nber)
      4 plt.show()

File ~/Dropbox/Projects/SpanVol/src/spanvol/utilities.py:20, in add_nber_shade(ax, nber_df, alpha)
      8 def add_nber_shade(ax: plt.Axes, nber_df: pd.DataFrame, alpha: float=0.2):
      9     """
     10     Adds NBER recession shades to a singe plt.axes (tipically an "ax").
     11 
   (...)
     18         plt.Axes: returns the same axes but with shades
     19     """
---> 20     min_year = pd.to_datetime(min(ax.lines[0].get_xdata())).year
     21     nber_to_keep = nber_df[pd.to_datetime(nber_df["start"]).dt.year >= min_year]
     23     for start, end in zip(nber_to_keep["start"], nber_to_keep["end"]):

File ~/miniconda3/envs/volatility/lib/python3.11/site-packages/pandas/core/tools/datetimes.py:1146, in to_datetime(arg, errors, dayfirst, yearfirst, utc, format, exact, unit, infer_datetime_format, origin, cache)
   1144         result = convert_listlike(argc, format)
   1145 else:
-> 1146     result = convert_listlike(np.array([arg]), format)[0]
   1147     if isinstance(arg, bool) and isinstance(result, np.bool_):
...
File tslib.pyx:552, in pandas._libs.tslib.array_to_datetime()

File tslib.pyx:541, in pandas._libs.tslib.array_to_datetime()

TypeError: <class 'pandas._libs.tslibs.period.Period'> is not convertible to datetime, at position 0

I assume this happens because pandas transforms the index under the hood and stores it as a different class (Period) on the plotted line. Is there a simple way to either fix the function or prevent pandas from doing this? Thanks a lot!

Best answer, by BigBen:

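A simple fix is to pass x_compat=True when plotting from pandas. This suppresses pandas' automatic tick resolution adjustment for regular time series, so the x data should stay as datetimes instead of being converted to Period: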

df.plot(y=["A"], lw = 0.2, ax = plt.gca(), legend=None, x_compat=True)
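
For completeness, here is a self-contained sketch of the same setting applied to a smaller version of the question's data. It uses the plot_params.use context manager described in the pandas visualization docs instead of the keyword argument. The Agg backend and the 300-row frame are assumptions made here so the sketch runs without a display. It only shows that the line's x data can then be passed to pd.to_datetime.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so no window is needed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(
    np.random.randn(300, 2),
    columns=list("AB"),
    index=pd.date_range(start="1970-01-01", periods=300, freq="W"),
)

fig, ax = plt.subplots()

# Every pandas plot call inside this block behaves as if x_compat=True
# had been passed explicitly.
with pd.plotting.plot_params.use("x_compat", True):
    df.plot(y=["A"], lw=0.2, ax=ax, legend=None)

# The x data are datetime-like here, so the conversion that failed
# in the question now succeeds.
first_x = min(ax.lines[0].get_xdata())
first_year = pd.to_datetime(first_x).year
print(type(first_x).__name__, first_year)

plt.close(fig)
```

The context manager is convenient when several pandas plots in one block need the same setting. For a single call, the keyword argument shown above is shorter.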

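
If changing the plotting call is not an option, the failing line in add_nber_shade could instead be made tolerant of Period values. The following is a minimal sketch and not part of the original answer. min_plotted_year is a hypothetical helper name, and it only addresses the TypeError raised by pd.to_datetime. Whether the later ax.axvspan calls land correctly on pandas' period-based axis has not been verified here.

```python
import pandas as pd


def min_plotted_year(xdata) -> int:
    """Return the year of the earliest x value (hypothetical helper)."""
    first = min(xdata)
    if isinstance(first, pd.Period):
        # Period objects cannot go through pd.to_datetime;
        # start_time is the Timestamp at which the period begins.
        return first.start_time.year
    # datetime, numpy.datetime64 and Timestamp values convert directly.
    return pd.to_datetime(first).year


# Both kinds of x data yield a plain integer year.
periods = pd.period_range("1970-01", periods=3, freq="M")
dates = pd.date_range("1970-01-01", periods=3, freq="D")
print(min_plotted_year(periods), min_plotted_year(dates))
```

Inside add_nber_shade, the line that computes min_year would then become min_year = min_plotted_year(ax.lines[0].get_xdata()).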