LoginSignup
2
2

論文でよく使うmatplotlib subplotsの覚え書き

Last updated at Posted at 2021-11-03

論文を書く際に複数枚のグラフの内,最も左にあるグラフのみx軸を表記し,最も下にあるグラフのみy軸を表記するという形式をよく利用するので覚え書きします.(ちなみにx軸, y軸内で表示すべき領域が完全に一致している場合はsharex, shareyという選択肢もあります.)
また,左下隅のグラフのみ消すという操作もよく利用するので,customizeの例 (ax.axis('off') によって可能です) に載せています.

from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import matplotlib.pyplot as plt


plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 18
plt.rcParams["mathtext.fontset"] = "stix"  # The setting of math font
plt.rcParams["text.usetex"] = True


@dataclass
class GridSpecKeywords:
    """
    Attributes:
        width_ratios, height_ratios (Optional[List[int]]):
            The aspect ratio of each subplot
        hspace, wspace (float):
            The margin between each subplot
    """
    width_ratios: Optional[List[int]] = None
    height_ratios: Optional[List[int]] = None
    wspace: float = 0.025
    hspace: float = 0.05


@dataclass
class TickParams:
    """
    Attributes:
        labelleft, labelbottom (bool):
            Whether to use the tick label for left and bottom axis
        left, bottom (bool):
            Whether to use the tick for left and bottom axis
    """
    labelleft: bool = False
    labelbottom: bool = False
    left: bool = False
    bottom: bool = False


def plot_func(ax: plt.Axes, row: int, col: int, n_rows: int, n_cols: int) -> None:
    # Customize HERE!!!
    if row + 1 == n_rows and col == 0:
        ax.axis('off')
    elif row == 1 and col == 1:
        f = lambda x: 1.0 / np.sqrt(2 * np.pi) * np.exp(- 0.5 * (x - 0.5) ** 2)
        x = np.linspace(-3, 3, 100)
        ax.plot(x, f(x))
    elif row == 0 and col == 0:
        f = lambda x: 1.0 / np.sqrt(2 * np.pi) * np.exp(- 0.5 * (x - 0.5) ** 2)
        x = np.linspace(-3, 3, 100)
        ax.plot(-f(x), x)
    else:
        X = np.random.multivariate_normal(mean=[0.5, -0.5], cov=np.identity(2), size=100)
        ax.scatter(X[:, 0], X[:, 1])
        ax.grid()


def show_plot(n_rows: int, n_cols: int, gs: GridSpecKeywords) -> None:
    tp = TickParams()
    _, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(8, 6), gridspec_kw=gs.__dict__, squeeze=False)

    for row in range(n_rows):
        for col in range(n_cols):
            display_bottom = row + 1 == n_rows
            display_left = col == 0
            tp.labelbottom = tp.bottom = display_bottom
            tp.labelleft = tp.left = display_left

            ax = axes[row][col]
            ax.tick_params(**tp.__dict__)

            plot_func(ax=ax, row=row, col=col, n_rows=n_rows, n_cols=n_cols)

    plt.show()


if __name__ == '__main__':
    # Modify here to change the configuration
    gs = GridSpecKeywords(width_ratios=[1, 7], height_ratios=[7, 1])

    show_plot(n_rows=2, n_cols=2, gs=gs)

なお,legendをグラフ下部に共有したい場合は以下のようにします.

axes[-1][center_col_idx].legend(
    handles=lines,
    loc='upper center',
    labels=labels,
    fontsize=24,
    bbox_to_anchor=(0.5, -0.25),  # ここは調整が必要です
    fancybox=False,
    ncol=len(labels)
)

さらにlabelを共有したい場合は fig.supxlabel or fig.subylabel を利用します.

また,tight_layoutでうまくいかない場合は以下を使います.

plt.savefig(fig_name, bbox_inches='tight')

gridに対して対数Plotをする場合は以下のようにGridを定義します.

ax.grid(which='minor', color='gray', linestyle=':')
ax.grid(which='major', color='black')
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2