Why are there grid lines for some plots and not for others?


Consider this MWE:

from matplotlib import pyplot
pyplot.style.use('classic')
pyplot.rcParams.update( {
    'text.usetex': True,
    "font.family": "serif",
    'text.latex.preamble': r'\usepackage{amsmath, amssymb, mlmodern}', } )
import numpy


fig, ax = pyplot.subplots(3,3)
fig.tight_layout()
fig.subplots_adjust(hspace=0.1, wspace=0.1,
    left=0.09, right=.95, top=.95, bottom=.09)

x = numpy.linspace(0, 8, 100)

for i in range(3):
    for j in range(3):
        ax[i,j].plot(x, numpy.sin((1+j)*x+numpy.pi*i), )
        ax[i,j].grid(which='both')  # <----------------- I added grid here
        if i!=2: ax[i,j].set_xticks([])
        if j==1: ax[i,j].set_yticks([])
        if j==2: ax[i,j].yaxis.tick_right()

ax[0,0].set_ylabel(r'$\phi=0$')
ax[1,0].set_ylabel(r'$\phi=\pi$')
ax[2,0].set_ylabel(r'$\phi=2\pi$')

ax[2,0].set_xlabel('$f = 1$')
ax[2,1].set_xlabel('$f = 2$')
ax[2,2].set_xlabel('$f = 3$')

pyplot.savefig('waves.png')

This produces the following plot:

[Figure: waves with varying frequency and phase]

I don't understand why matplotlib draws grid lines in both directions for axes (0,0) and (0,2), only vertical grid lines for axes (0,1), only horizontal grid lines for axes (1,0), (1,2), (2,0) and (2,2), and no grid lines at all for axes (1,1) and (2,1).

How do I make sure that all axes have the grid similar to (0,0)? Thank you.


Accepted answer by mozway:

Your issue comes from removing the ticks: grid lines are drawn at the tick positions, so an axis with no ticks has no grid lines. Instead, remove only the tick labels.
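A quick way to confirm this is to count the grid lines per panel. The sketch below rebuilds the question's layout without the LaTeX settings and the classic style (assumed not to matter here), renders with the off-screen Agg backend, and uses `Axes.get_xgridlines()` / `Axes.get_ygridlines()`, which return one grid line per tick:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy

fig, ax = plt.subplots(3, 3)
x = numpy.linspace(0, 8, 100)

for i in range(3):
    for j in range(3):
        ax[i, j].plot(x, numpy.sin((1 + j) * x + numpy.pi * i))
        # visible=True passed explicitly rather than relying on grid()'s toggle behaviour
        ax[i, j].grid(True, which='both')
        if i != 2: ax[i, j].set_xticks([])  # clears the x ticks of rows 0 and 1
        if j == 1: ax[i, j].set_yticks([])  # clears the y ticks of column 1

# One vertical grid line per x tick, one horizontal grid line per y tick
counts = {(i, j): (len(ax[i, j].get_xgridlines()), len(ax[i, j].get_ygridlines()))
          for i in range(3) for j in range(3)}
for key, (nx, ny) in counts.items():
    print(key, "vertical:", nx, "horizontal:", ny)
```

Panels in rows 0 and 1 report zero vertical grid lines and panels in column 1 report zero horizontal grid lines: exactly the axes whose ticks were cleared.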

Replace:

        if i!=2: ax[i,j].set_xticks([])
        if j==1: ax[i,j].set_yticks([])

with:

        if i!=2: ax[i,j].set_xticklabels([])
        if j==1: ax[i,j].set_yticklabels([])

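To check that the replacement keeps the grid, here is a sketch of the patched loop (again without the LaTeX settings, rendered with Agg). Blanking the labels leaves the tick locator in place, so every panel still has ticks, and therefore grid lines, in both directions:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy

fig, ax = plt.subplots(3, 3)
x = numpy.linspace(0, 8, 100)

for i in range(3):
    for j in range(3):
        ax[i, j].plot(x, numpy.sin((1 + j) * x + numpy.pi * i))
        ax[i, j].grid(True, which='both')
        if i != 2: ax[i, j].set_xticklabels([])  # blank labels; ticks stay
        if j == 1: ax[i, j].set_yticklabels([])  # blank labels; ticks stay
        if j == 2: ax[i, j].yaxis.tick_right()

# Every panel keeps ticks on both axes, hence grid lines in both directions
has_both = all(len(a.get_xgridlines()) > 0 and len(a.get_ygridlines()) > 0
               for a in ax.flat)
print(has_both)
```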

If you don't need the tick labels on the right, note that you can hide the inner tick labels automatically by setting sharex=True, sharey=True:

import matplotlib.pyplot as plt
import numpy

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
fig.tight_layout()
fig.subplots_adjust(hspace=0.1, wspace=0.1,
    left=0.09, right=.95, top=.95, bottom=.09)

x = numpy.linspace(0, 8, 100)

for i in range(3):
    for j in range(3):
        ax[i,j].plot(x, numpy.sin((1+j)*x+numpy.pi*i), )
        ax[i,j].grid(which='both')
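A sketch to verify what the shared-axes version does, assuming the Agg backend and checking label visibility after a draw: the inner panels lose their tick labels but keep their ticks, so the grid is still drawn everywhere:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
x = numpy.linspace(0, 8, 100)

for i in range(3):
    for j in range(3):
        ax[i, j].plot(x, numpy.sin((1 + j) * x + numpy.pi * i))
        ax[i, j].grid(True, which='both')

fig.canvas.draw()

def labels_visible(labels):
    """True if at least one tick label in the list is visible."""
    return any(t.get_visible() for t in labels)

inner_x_hidden = not labels_visible(ax[0, 0].get_xticklabels())  # top row
inner_y_hidden = not labels_visible(ax[0, 1].get_yticklabels())  # middle column
outer_x_shown = labels_visible(ax[2, 0].get_xticklabels())       # bottom row
outer_y_shown = labels_visible(ax[0, 0].get_yticklabels())       # left column
grid_everywhere = all(len(a.get_xgridlines()) > 0 and len(a.get_ygridlines()) > 0
                      for a in ax.flat)
print(inner_x_hidden, inner_y_hidden, outer_x_shown, outer_y_shown, grid_everywhere)
```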
Answer by bona_rivers:

The issue is in these two lines:

if i!=2: ax[i,j].set_xticks([])
if j==1: ax[i,j].set_yticks([])

because they remove the ticks themselves, not just their labels, and grid lines are drawn at the tick positions. To solve it, use set_xticklabels/set_yticklabels instead of set_xticks/set_yticks:

if i!=2: ax[i,j].set_xticklabels([])
if j==1: ax[i,j].set_yticklabels([])
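An alternative not used in either answer is `Axes.tick_params` with `labelbottom=False` / `labelleft=False`. It hides the label text while leaving the tick locator and formatter untouched (whereas `set_xticklabels` installs a new formatter). A minimal sketch, assuming the Agg backend and omitting the LaTeX settings:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy

fig, ax = plt.subplots(3, 3)
x = numpy.linspace(0, 8, 100)

for i in range(3):
    for j in range(3):
        a = ax[i, j]
        a.plot(x, numpy.sin((1 + j) * x + numpy.pi * i))
        a.grid(True, which='both')
        # Hide only the label text; tick marks and grid lines stay
        if i != 2: a.tick_params(axis='x', labelbottom=False)
        if j == 1: a.tick_params(axis='y', labelleft=False)
        if j == 2: a.yaxis.tick_right()

fig.savefig('waves.png')
```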