I am trying to customize the axes of a matplotlib plot. I am using the surpyval package to fit the data and then plot it. The plot method in surpyval does not accept any arguments other than `ax=ax`, as used below. My problem is that I can't match the handles and labels in the legend, as you can see from this example:
```python
import numpy as np
import surpyval as surv
import matplotlib.pyplot as plt

y_a = np.array([181, 183, 190, 190, 195, 195, 198, 198, 198, 201, 202, 202, 202,
                204, 205, 205, 206, 206, 206, 206, 207, 209, 213, 214, 218, 219])
y_s = np.array([161, 179, 196, 196, 197, 198, 204, 205, 209, 211, 215, 218, 227,
                230, 231, 232, 232, 236, 237, 237, 240, 243, 244, 246, 252, 255])

# Fit a Weibull distribution to each data set
model_1 = surv.Weibull.fit(y_a)
model_2 = surv.Weibull.fit(y_s)

# Draw both fits on the same axes
ax = plt.gca()
model_1.plot(ax=ax)
model_2.plot(ax=ax)

p_a = ['A', 'a_u_CI', 'a_l_CI', 'a_fit']
p_s = ['S', 's_u_CI', 's_l_CI', 's_fit']
p_t = p_a + p_s
# p_t[0:5:4] == ['A', 'S'], but these labels don't end up on the scatter points
ax.legend(labels=p_t[0:5:4])
plt.show()
```

As well as specifying your labels, you also need to tell `ax.legend` which "handles" to use. Normally with matplotlib you could label each line or collection as you plot it, but with surpyval this is not possible. Instead, we can find the relevant handles from the matplotlib `Axes` instance. In this case, assuming you want to label the blue and orange circles, you want the two `PathCollection` instances, which are created using matplotlib's `scatter` method under the hood. Fortunately, these are easy to access in the `ax.collections` attribute. So, you just need to change the legend line so that it passes `ax.collections` as the handles along with your labels.
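A minimal sketch of the changed legend call follows. Because surpyval may not be installed (and the answer reports problems with newer matplotlib releases), it uses a hypothetical stand-in, `fake_surpyval_plot`, which is not part of surpyval and only mimics what the answer says each `model.plot(ax=ax)` call adds: one scatter `PathCollection`, two red `Line2D` bounds and one dashed black `Line2D` fit. The x/y values are arbitrary demo data. With the real models, only the final `ax.legend(...)` line changes, e.g. `ax.legend(handles=ax.collections, labels=p_t[0:5:4])`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def fake_surpyval_plot(ax, x, y):
    """Hypothetical stand-in for model.plot(ax=ax); NOT part of surpyval."""
    ax.plot(x, np.c_[y - 0.05, y + 0.05], color="r")  # two red bound lines
    ax.plot(x, y, color="k", linestyle="dashed")      # black dashed fit line
    ax.scatter(x, y)                                  # data points (PathCollection)


fig, ax = plt.subplots()
x = np.linspace(180, 220, 10)  # arbitrary demo data
y = np.linspace(0.1, 0.9, 10)
fake_surpyval_plot(ax, x, y)       # in place of model_1.plot(ax=ax)
fake_surpyval_plot(ax, x + 20, y)  # in place of model_2.plot(ax=ax)

p_a = ['A', 'a_u_CI', 'a_l_CI', 'a_fit']
p_s = ['S', 's_u_CI', 's_l_CI', 's_fit']
p_t = p_a + p_s

# The fix: pass the two scatter collections as handles, in plotting order.
leg = ax.legend(handles=list(ax.collections), labels=p_t[0:5:4])
```

The handles are matched to the labels by position, so the first scatter drawn gets "A" and the second gets "S".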
For reference, the solid red lines and dashed black lines are stored in `ax.lines`. For this plot, there are 6 `Line2D` objects stored there, corresponding to the 4 red lines and 2 black lines on the plot.

In answer to the comment below ("Do you know how to turn the solid lines into fill_between?"): to fill between two of the lines, you can grab the x- and y-data from the two lines and then pass them to matplotlib's `fill_between` function.
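For example, to fill between the two red lines for the "A" data, take the data from those two lines with `get_xdata()` and `get_ydata()` and pass it to `ax.fill_between`.
Note that you need to set the `zorder` to something lower than that of the scatter points; otherwise the filled area will hide them.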
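A minimal sketch of that fill for the "A" data. It assumes the two red bounds for "A" are `ax.lines[0]` and `ax.lines[1]` (consistent with the other pair being `ax.lines[3]` and `ax.lines[4]`) and that both bounds share the same x-data. `fake_surpyval_plot` is a hypothetical stand-in for `model.plot(ax=ax)`, not part of surpyval, and the data are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def fake_surpyval_plot(ax, x, y):
    """Hypothetical stand-in for model.plot(ax=ax); NOT part of surpyval."""
    ax.plot(x, np.c_[y - 0.05, y + 0.05], color="r")  # two red bound lines
    ax.plot(x, y, color="k", linestyle="dashed")      # black dashed fit line
    ax.scatter(x, y)                                  # data points


fig, ax = plt.subplots()
x = np.linspace(180, 220, 10)  # arbitrary demo data
y = np.linspace(0.1, 0.9, 10)
fake_surpyval_plot(ax, x, y)       # in place of model_1.plot(ax=ax)
fake_surpyval_plot(ax, x + 20, y)  # in place of model_2.plot(ax=ax)

# Assumed: the two red bounds for "A" are the first two lines drawn.
first, second = ax.lines[0], ax.lines[1]
band = ax.fill_between(
    first.get_xdata(),
    first.get_ydata(),
    second.get_ydata(),
    color="r",
    alpha=0.3,
    zorder=0,  # collections such as scatter default to zorder 1, so this sits below them
)
```

`fill_between` adds its band to `ax.collections` as well, so if you also build the legend from `ax.collections`, do that before adding the band or pass only the first two entries.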
For the other two lines in this plot, you would want to use `ax.lines[3]` and `ax.lines[4]`.

Note that this was all tested with matplotlib v3.6, surpyval v0.10.10 and Python 3.9. I found that a later release of matplotlib (v3.7.2, Sept 2023) did not work, due to some issues with the way surpyval tries to plot the grid.
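To turn both pairs of solid red lines into shaded bands at once (the "A" pair and the `ax.lines[3]`/`ax.lines[4]` pair), one option is to select the red lines by colour instead of hard-coding indices, pair them up in drawing order, and hide the originals. This is a sketch under the assumption that each surpyval plot call draws its two red bounds consecutively on a shared x grid; `lines_to_bands` and `fake_surpyval_plot` are hypothetical names, the latter standing in for `model.plot(ax=ax)` with arbitrary data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import same_color


def fake_surpyval_plot(ax, x, y):
    """Hypothetical stand-in for model.plot(ax=ax); NOT part of surpyval."""
    ax.plot(x, np.c_[y - 0.05, y + 0.05], color="r")  # two red bound lines
    ax.plot(x, y, color="k", linestyle="dashed")      # black dashed fit line
    ax.scatter(x, y)                                  # data points


def lines_to_bands(ax, color="r", alpha=0.3):
    """Replace each consecutive pair of `color` lines with a filled band.

    Assumes the two bounds of a band are drawn one after the other and
    share the same x-data.
    """
    bounds = [ln for ln in ax.lines if same_color(ln.get_color(), color)]
    bands = []
    for first, second in zip(bounds[0::2], bounds[1::2]):
        bands.append(ax.fill_between(first.get_xdata(), first.get_ydata(),
                                     second.get_ydata(), color=color,
                                     alpha=alpha, zorder=0))  # below the scatter points
        first.set_visible(False)   # hide the solid lines
        second.set_visible(False)
    return bands


fig, ax = plt.subplots()
x = np.linspace(180, 220, 10)  # arbitrary demo data
y = np.linspace(0.1, 0.9, 10)
fake_surpyval_plot(ax, x, y)       # in place of model_1.plot(ax=ax)
fake_surpyval_plot(ax, x + 20, y)  # in place of model_2.plot(ax=ax)
bands = lines_to_bands(ax)
```

Hiding the lines with `set_visible(False)` rather than removing them keeps the `ax.lines` indices unchanged, so any code that still indexes into `ax.lines` keeps pointing at the same objects.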