Python: generate multiple subplots

Software requirements:

  • Python 3

  • numpy

  • matplotlib

  • cartopy

Example script

multi_subplots_grid.py

#!/usr/bin/env python
# coding: utf-8
'''
DKRZ example

Multiple subplots

This script demonstrates how to generate a 'panel plot' with 7 plots drawn
in a 3x3 subplot grid, add a common colorbar in the second-to-last axis, and
suppress the drawing of the last axis.

Content

- generate random data of shape (7, 50, 100)
- draw the data in a subplot grid of 3 rows x 3 columns
- add a common title
- add a common colorbar to the second-to-last axis
- don't draw the last axis
- save to PNG

-------------------------------------------------------------------------------
2024 copyright DKRZ licensed under CC BY-NC-SA 4.0 <br>
               (https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en)
-------------------------------------------------------------------------------
'''
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colorbar as colorbar
import cartopy.crs as ccrs

def main():
    '''Draw 7 random-data map panels in a 3x3 grid and save them as a PNG.

    Each of the first ``nt`` grid cells shows one time slice of the random
    data on a Transverse Mercator map. The second-to-last cell is hidden and
    its area is reused for a common vertical colorbar; every other unused
    cell is hidden as well. The figure is written to
    'plot_multi_subplots_example.png' in the current working directory.
    '''
    plt.switch_backend('agg')

    #-- Generate reproducible random integers in [vmin, vmax).
    #-- Draw from the seeded Generator: the legacy np.random.randint would
    #-- ignore the seed and give different data on every run.
    vmin, vmax = 0, 100
    nt, nx, ny = 7, 100, 50

    rng = np.random.default_rng(seed=42)
    data = rng.integers(vmin, vmax, size=(nt, ny, nx))

    #-- Let's assume that the random data is in the European area.
    lon = np.linspace(-40.375, 75.375, nx)
    lat = np.linspace(25.375, 75.375, ny)

    #-- Set the map projection and data transformation.
    projection = ccrs.TransverseMercator(central_longitude=11.,
                                         central_latitude=50.)
    transform = ccrs.PlateCarree()

    #-- Map extent shown in every panel: [lon_min, lon_max, lat_min, lat_max].
    extent = [lon.min()+31., lon.max()-35., lat.min()+6., lat.max()-5]

    #-- Shared colour settings so the panels and the colorbar match.
    cmap, alpha = 'Blues', 0.25

    #-- Set the number of rows and columns for the subplots grid.
    nrows, ncols = 3, 3
    cbar_index = nrows*ncols - 2   # second-to-last cell hosts the colorbar

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 8),
                             constrained_layout=True,
                             subplot_kw=dict(projection=projection))
    #-- common title; data.shape is (nt, ny, nx)
    fig.suptitle(f'Random data: shape {data.shape}', fontsize=16, weight='bold')

    for i, ax in enumerate(axes.flat):
        #-- draw one data slice per grid cell
        if i < nt and i != cbar_index:
            ax.set_title(f'data {i}', fontsize=8)
            ax.set_extent(extent)
            ax.coastlines(resolution='50m', lw=1)
            ax.pcolormesh(lon, lat, data[i], cmap=cmap, alpha=alpha,
                          vmin=vmin, vmax=vmax, transform=transform)
            continue

        #-- no data for this cell (or it is the colorbar cell): hide the map
        ax.set_visible(False)

        #-- create the common colorbar in the second-to-last cell
        if i == cbar_index:
            # NOTE(review): with constrained_layout, final axes positions are
            # only computed at draw time, so this is the pre-layout position;
            # confirm the colorbar still lines up with the grid.
            bbox = ax.get_position()
            cax = fig.add_axes([bbox.x0, bbox.y0 - 0.045,
                                bbox.width / 14, bbox.height - 0.03],
                               autoscalex_on=True)
            colorbar.Colorbar(cax, orientation='vertical',
                              cmap=cmap, alpha=alpha,
                              norm=plt.Normalize(vmin, vmax))

    #-- save figure to PNG file, then release it (pyplot keeps a reference)
    fig.savefig('plot_multi_subplots_example.png', bbox_inches='tight',
                facecolor='white')
    plt.close(fig)


if __name__ == '__main__':
    main()


Plot result:

image0