How do I update the matplotlib elements of a sympy plot?


The code below plots the inequality y > 5/x with Matplotlib and re-fills the shaded region as the user pans or zooms outward (the update runs whenever the mouse button is released).

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

x_range = np.linspace(-10, 10, 400)
y_range = 5 / x_range

line, = ax.plot(x_range, y_range, 'r', linewidth=2, linestyle='--')

ax.fill_between(x_range, y_range, y_range.max(), alpha=0.3, color='gray')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Inequality: y > 5 / x')

ax.axhline(0, color='black', linewidth=3)
ax.axvline(0, color='black', linewidth=3)

ax.grid(color='gray', linestyle='--', linewidth=0.5)

def update_limits(event):
    xlim = ax.get_xlim()
    x_range = np.linspace(xlim[0], xlim[1], max(200, int(200 * (xlim[1] - xlim[0]))))
    y_range = 5 / x_range
    line.set_data(x_range, y_range)
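    # iterate over a copy, since removing artists modifies ax.collections
    for collection in list(ax.collections):
        collection.remove()
    ax.fill_between(x_range, y_range, max(y_range.max(), ax.get_ylim()[1]), alpha=0.3, color='gray')
    # ask the existing window to repaint
    fig.canvas.draw_idle()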

fig.canvas.mpl_connect('button_release_event', update_limits)
plt.show()
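The key step in update_limits is replacing the old shaded region: fill_between returns a new artist on every call, so the previous one has to be removed first. A minimal headless sketch of that remove-then-recreate pattern follows; the refill helper name and the Agg backend are assumptions made only for illustration. It iterates over a copy of ax.collections so that removing artists does not disturb the loop.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

def refill(xmin, xmax):
    """Hypothetical helper: redraw the y > 5/x shading for a new x-range."""
    x = np.linspace(xmin, xmax, 400)
    x = x[x != 0]  # guard against x = 0, where 5/x is undefined
    y = 5 / x
    # Iterate over a copy: removing an artist changes ax.collections.
    for collection in list(ax.collections):
        collection.remove()
    return ax.fill_between(x, y, y.max(), alpha=0.3, color="gray")

# Simulate three successive zoom-outs.
for xmax in (10, 20, 50):
    refill(-xmax, xmax)

print(len(ax.collections))  # prints 1: only the latest fill remains
```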

I've been trying to convert this concept into code that uses the sympy module (through its Matplotlib backend), which has better control over algebraic functions. However, the code below does not fill in the graph as the user pans outward. Why is this the case, and how do I fix it?

import sympy as sp
from sympy.plotting.plot import MatplotlibBackend

x, y = sp.symbols('x y')

# Define the implicit plot
p1 = sp.plot_implicit(sp.And(y > 5 / x), (x, -10, 10), (y, -10, 10), show=False)

mplPlot = MatplotlibBackend(p1)

mplPlot.process_series()
mplPlot.fig.tight_layout()
mplPlot.ax[0].set_xlabel("x-axis")
mplPlot.ax[0].set_ylabel("y-axis")


def update_limits(event):
    global mplPlot
    xmin, xmax = mplPlot.ax[0].get_xlim()
    ymin, ymax = mplPlot.ax[0].get_ylim()

    p1 = sp.plot_implicit(sp.And(y > 5 / x), (x, xmin, xmax), (y, ymin, ymax), show=False)

    mplPlot = MatplotlibBackend(p1)
    mplPlot.process_series()
    mplPlot.fig.tight_layout()

    mplPlot.ax[0].set_xlabel("x-axis")
    mplPlot.ax[0].set_ylabel("y-axis")
    mplPlot.plt.draw()



mplPlot.fig.canvas.mpl_connect('button_release_event', update_limits)
mplPlot.plt.show()

UPDATE: After some debugging of the sympy code, I found that the xmin and xmax variables in the update_limits function change only once and then stay the same for the rest of the run. If possible, I would also like to know why this is.
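A plausible explanation, sketched under an assumption: in the SymPy versions I know of, the MatplotlibBackend constructor creates a brand-new Figure with plt.figure(). If so, every release builds a new, never-shown figure, and global mplPlot is rebound to it. The callback is still connected to the original window's canvas, but from the second release on it reads the limits of the latest hidden figure, which nobody pans, so xmin and xmax stop changing. The plain-Matplotlib sketch below reproduces that pattern; FakeBackend is a hypothetical stand-in for MatplotlibBackend(p1), and Agg is selected only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean slate

class FakeBackend:
    """Hypothetical stand-in for MatplotlibBackend: one new Figure per instance."""
    def __init__(self, xlim):
        self.fig = plt.figure()
        self.ax = [self.fig.add_subplot(1, 1, 1)]
        self.ax[0].set_xlim(*xlim)

mpl_plot = FakeBackend((-10, 10))
shown_ax = mpl_plot.ax[0]  # the Axes the user actually sees and pans
seen = []

def update_limits():
    global mpl_plot
    # Reads limits from whatever figure the global currently points to.
    seen.append(mpl_plot.ax[0].get_xlim())
    mpl_plot = FakeBackend(mpl_plot.ax[0].get_xlim())  # yet another new figure

shown_ax.set_xlim(-20, 20)  # user pans the visible window
update_limits()             # reads (-20, 20) from the visible figure
shown_ax.set_xlim(-50, 50)  # user pans again
update_limits()             # reads the hidden figure: still (-20, 20)

print(seen)
print(len(plt.get_fignums()))  # three figures now exist
```

It also suggests why calling plt.show() inside the callback opens extra windows: pyplot displays every figure it has created, including the ones built inside update_limits.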

UPDATE 2: If I run mplPlot.plt.show() instead of mplPlot.plt.draw(), a new window is created with the correct graph. This is not what I want, as I want the changes to appear in the same window. Doing this also reveals another bug: when I pan far enough into Quadrant IV of the graph, the graph seems to become unresponsive. This doesn't happen every time and I can't find an explanation for it. If anyone knows why, feel free to include that in your answer!
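A minimal sketch of one way around the new-window and frozen-limits problems, assuming it is acceptable to draw the region yourself rather than through plot_implicit: keep a single Figure and update its artists in place, using SymPy only to turn the expression into a NumPy function with sympy.lambdify. fig.canvas.draw_idle() asks the existing window to repaint, whereas plt.show() displays every pending figure. The redraw and on_release names are illustrative, and Agg is selected only so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless here; drop this line for an interactive window
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp

plt.close("all")

x = sp.symbols("x")
boundary = sp.lambdify(x, 5 / x, "numpy")  # boundary of y > 5/x as a NumPy function

fig, ax = plt.subplots()
line, = ax.plot([], [], "r--", linewidth=2)

def redraw(xmin, xmax, ymax):
    """Recompute line and shading for the current view, reusing fig/ax."""
    xs = np.linspace(xmin, xmax, 400)
    xs = xs[xs != 0]  # 5/x is undefined at x = 0
    ys = boundary(xs)
    line.set_data(xs, ys)
    for coll in list(ax.collections):  # remove the previous fill
        coll.remove()
    ax.fill_between(xs, ys, max(np.nanmax(ys), ymax), alpha=0.3, color="gray")
    fig.canvas.draw_idle()  # repaint this window; no new figure is created

def on_release(event):
    xmin, xmax = ax.get_xlim()
    redraw(xmin, xmax, ax.get_ylim()[1])

redraw(-10, 10, 10)
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
fig.canvas.mpl_connect("button_release_event", on_release)
# plt.show()  # uncomment when using an interactive backend
```

Because no new Figure is ever created, ax.get_xlim() always reflects the limits the user is looking at.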

Answer by Davide_sd

I'm going to use the SymPy Plotting Backends module (spb), because it already has a lot of code written for interactivity. At the time of writing this answer, version 3.1.1 is out, but it doesn't implement pan/zoom events. Still, we can easily implement them ourselves.
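Since the events have to be wired up through Matplotlib itself, it may help to know that, besides button_release_event (used in the code below), every Matplotlib Axes emits an xlim_changed callback whenever its x-limits change, whether from the toolbar's pan/zoom or from code. A minimal headless sketch with plain Matplotlib; the on_xlim_changed name and the Agg backend are assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
calls = []

def on_xlim_changed(axes):
    # Matplotlib passes the Axes whose limits changed.
    calls.append(axes.get_xlim())

cid = ax.callbacks.connect("xlim_changed", on_xlim_changed)
ax.set_xlim(-20, 20)  # simulates a pan/zoom
ax.set_xlim(-5, 5)
ax.callbacks.disconnect(cid)
ax.set_xlim(0, 1)     # no longer reported

print(len(calls))  # prints 2
```

The trade-off: during a drag this callback can fire on every intermediate limit change, so recomputing expensive data inside it can make panning sluggish, which is a reason to prefer button_release_event when regenerating plot data.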

The first thing you can try is an implicit plot:

# Jupyter magic for interactive figures (requires the ipympl package)
%matplotlib widget
from sympy import *
from spb import *
from matplotlib.colors import ListedColormap
var("x, y")
     
g = graphics(
    implicit_2d(
        y>5/x, (x, -10, 10), (y, -200, 200), n=500,
        rendering_kw={"alpha": 0.3, "cmap": ListedColormap(["#ffffff00", "gray"])},
        border_kw={"linestyles": "--", "cmap": ListedColormap(["r", "r"])}
    ),
    xlabel="x", ylabel="y", title="y > 5/x", show=False
)

def _update_axis_limits(event):
    xlim = g.ax.get_xlim()
    ylim = g.ax.get_ylim()
    limits = [xlim, ylim]
    all_params = {}
    for s in g.series:
        new_ranges = []
        for r, l in zip(s.ranges, limits):
            new_ranges.append((r[0], *l))
        s.ranges = new_ranges
        # inform the data series that they must generate new data on next update
        s._interactive_ranges = True
        s.is_interactive = True
        # extract any existing parameters
        all_params = g.merge({}, all_params)
    # create new data and update the plot
    g.update_interactive(all_params)

g.show(block=False)
g.fig.canvas.mpl_connect('button_release_event', _update_axis_limits)

[Image: resulting plot of y > 5/x]

There is a problem with this visualization: 5/x is undefined at x = 0, so the vertical red dashed line there shouldn't appear. Implicit plotting relies on Matplotlib's contour functionality, and it is extremely hard to represent undefined points or lines in a contour plot.
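With ordinary line and fill plots, by contrast, a gap is easy to express: Matplotlib does not draw line segments that touch a NaN, and fill_between treats NaN as missing data, splitting the shaded region into separate pieces. A minimal plain-Matplotlib sketch of that idea, inserting a NaN at x = 0 by hand (I'd assume spb's exclude option, used later, works along similar lines, but that is an assumption):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-10, 10, 400)  # even count, so x = 0 itself is never sampled
y = 5 / x

# Insert an explicit NaN between the two branches, at x = 0.
i = np.searchsorted(x, 0.0)
x = np.insert(x, i, 0.0)
y = np.insert(y, i, np.nan)

fig, ax = plt.subplots()
ax.plot(x, y, "r--")  # the NaN breaks the dashed line into two pieces
fill = ax.fill_between(x, y, np.nanmax(y), alpha=0.3, color="gray")
print(len(fill.get_paths()))  # one polygon on each side of x = 0
```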

We can improve the visualization by using fill_between, like in your first approach. However, the plotting module doesn't implement that functionality, so we must create it ourselves, like this:

import numpy as np
from sympy import *
from spb import *
from spb.series import LineOver1DRangeSeries
from spb.backends.matplotlib.renderers.renderer import MatplotlibRenderer

# Things to know in order to extend the plotting module capabilities:
#
# 1. data is generated by some data series.
# 2. each data series is mapped to a particular renderer

# Let's create a data series.
# Matplotlib's `fill_between`: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.fill_between.html
# requires the coordinates of the first curve and second curve.
# The first curve is just an ordinary line corresponding to the y-values of your expressions.
# The second curve is the maximum value of the y-coordinate of the first curve.
# So, the data is identical to the one we would use to render a line.
class FillBetweenSeries(LineOver1DRangeSeries):
    pass

# Next, we must create a renderer, which must have two methods:
#
# 1. a `draw` method, where the initial handle will be created.
# 2. an `update` method, where the handle will be updated with new data.
# 
# More information can be found here: https://sympy-plot-backends.readthedocs.io/en/latest/tutorials/tut-6.html
def draw(renderer, data):
    ax = renderer.plot.ax
    s = renderer.series
    merge = renderer.plot.merge
    x, y = data
    rkw = merge({}, {"y2": np.nanmax(y)}, s.rendering_kw)
    handles = [
        ax.fill_between(x, y, **rkw)
    ]
    return handles

def update(renderer, data, handles):
    for h in handles:
        h.remove()
    handles[0] = draw(renderer, data)[0]
    
class FillBetweenRenderer(MatplotlibRenderer):
    draw_update_map = {
        draw: update
    }

# Next, we need to inform the backend that when a FillBetweenSeries is
# encountered, it will be rendered with FillBetweenRenderer
MB.renderers_map[FillBetweenSeries] = FillBetweenRenderer

# Next, let's create a function similar to the ones exposed by the plotting module.
# Note that it returns a list of series. Depending on the visualization
# that you are trying to create, you may need more than one data
# series...
def fill_between(expr, range, label="", rendering_kw=None, **kwargs):
    """
    Parameters
    ----------
    expr : Expr
        The symbolic expression
    range : (symbol, min, max)
        Initial range in which to plot the expression
    label : str
        Optional label to be shown on the legend
    rendering_kw : dict, optional
        A dictionary containing keys/values which are going to
        be passed to matplotlib's fill_between.
    **kwargs :
        Keyword arguments related to `line()`.
    """
    # avoid sharing a mutable default dictionary between calls
    rendering_kw = {} if rendering_kw is None else rendering_kw
    series = [
        FillBetweenSeries(expr, range, label, rendering_kw=rendering_kw, **kwargs)
    ]
    return series

# Finally, we are ready to use it
var("x")
g = graphics(
    fill_between(
        5/x, (x, -10, 10), rendering_kw={"color": "gray", "alpha": 0.3},
        exclude=[0] # exclude this point, where the function is undefined
    ),
    line(
        5/x, (x, -10, 10), rendering_kw={"color": "r", "linestyle": "--"},
        exclude=[0] # exclude this point, where the function is undefined
    ),
    xlabel="x", ylabel="y", title="y > 5/x", show=False, legend=False
)

def _update_axis_limits(event):
    xlim = g.ax.get_xlim()
    ylim = g.ax.get_ylim()
    limits = [xlim, ylim]
    all_params = {}
    for s in g.series:
        new_ranges = []
        for r, l in zip(s.ranges, limits):
            new_ranges.append((r[0], *l))
        s.ranges = new_ranges
        # inform the data series that they must generate new data on next update
        s._interactive_ranges = True
        s.is_interactive = True
        # extract any existing parameters
        all_params = g.merge({}, all_params)
    # create new data and update the plot
    g.update_interactive(all_params)

g.show(block=False)
g.fig.canvas.mpl_connect('button_release_event', _update_axis_limits)

[Image: resulting plot of y > 5/x using fill_between]
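The heart of _update_axis_limits is rebuilding each series' (symbol, min, max) ranges from the current axis limits. Because zip stops at the shorter sequence, a 1D series such as the fill and line series (a single x-range) only picks up the x-limits, while a 2D implicit series picks up both. A standalone sketch of just that step; rebuild_ranges is a hypothetical helper, and the made-up limits stand in for g.ax.get_xlim() and g.ax.get_ylim():

```python
from sympy import symbols

x, y = symbols("x y")

def rebuild_ranges(ranges, limits):
    """Replace the numeric bounds of each (symbol, min, max) range."""
    return [(r[0], *lim) for r, lim in zip(ranges, limits)]

limits = [(-25.0, 25.0), (-400.0, 400.0)]  # hypothetical (xlim, ylim) after a pan

print(rebuild_ranges([(x, -10, 10), (y, -200, 200)], limits))
# [(x, -25.0, 25.0), (y, -400.0, 400.0)]
print(rebuild_ranges([(x, -10, 10)], limits))
# [(x, -25.0, 25.0)]
```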