論文を書く際に複数枚のグラフの内,最も左にあるグラフのみx軸を表記し,最も下にあるグラフのみy軸を表記するという形式をよく利用するので覚え書きします.(ちなみにx軸, y軸内で表示すべき領域が完全に一致している場合はsharex
, sharey
という選択肢もあります.)
また,左下隅のグラフのみ消すという操作もよく利用するので,customizeの例 (ax.axis("off")
によって可能です) に載せています.
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
# Type 3 --> Type 1 (Many venues require Type 1 font.)
plt.rcParams["pdf.fonttype"] = 42
plt.rcParams["ps.fonttype"] = 42
# My Preferences.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 18
plt.rcParams["mathtext.fontset"] = "stix" # The setting of math font
plt.rcParams["text.usetex"] = True
@dataclass
class GridSpecKeywords:
"""
Attributes:
width_ratios, height_ratios (list[int] | None):
The aspect ratio of each subplot
hspace, wspace (float):
The margin between each subplot
"""
width_ratios: list[int] | None = None
height_ratios: list[int] | None = None
wspace: float = 0.025
hspace: float = 0.05
@dataclass
class TickParams:
"""
Attributes:
labelleft, labelbottom (bool):
Whether to use the tick label for left and bottom axis
left, bottom (bool):
Whether to use the tick for left and bottom axis
"""
labelleft: bool = False
labelbottom: bool = False
left: bool = False
bottom: bool = False
def plot_func(ax: plt.Axes, row: int, col: int, n_rows: int, n_cols: int) -> None:
# Customize HERE!!!
if row + 1 == n_rows and col == 0:
ax.axis("off")
elif row == 1 and col == 1:
f = lambda x: 1.0 / np.sqrt(2 * np.pi) * np.exp(- 0.5 * (x - 0.5) ** 2)
x = np.linspace(-3, 3, 100)
ax.plot(x, f(x))
elif row == 0 and col == 0:
f = lambda x: 1.0 / np.sqrt(2 * np.pi) * np.exp(- 0.5 * (x - 0.5) ** 2)
x = np.linspace(-3, 3, 100)
ax.plot(-f(x), x)
else:
X = np.random.multivariate_normal(mean=[0.5, -0.5], cov=np.identity(2), size=100)
ax.scatter(X[:, 0], X[:, 1])
ax.grid()
def show_plot(n_rows: int, n_cols: int, gs: GridSpecKeywords) -> None:
tp = TickParams()
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(8, 6), gridspec_kw=gs.__dict__, squeeze=False)
for row in range(n_rows):
for col in range(n_cols):
display_bottom = row + 1 == n_rows
display_left = col == 0
tp.labelbottom = tp.bottom = display_bottom
tp.labelleft = tp.left = display_left
ax = axes[row][col]
ax.tick_params(**tp.__dict__)
plot_func(ax=ax, row=row, col=col, n_rows=n_rows, n_cols=n_cols)
plt.show()
if __name__ == "__main__":
# Modify here to change the configuration
gs = GridSpecKeywords(width_ratios=[1, 7], height_ratios=[7, 1])
show_plot(n_rows=2, n_cols=2, gs=gs)
なお,legendをグラフ下部に共有したい場合は以下のようにします.
fig.legend(
handles=lines,
loc="upper center",
labels=labels,
fontsize=24,
bbox_to_anchor=(0.5, -0.25), # ここは調整が必要です
fancybox=False,
ncol=len(labels)
)
さらにlabelを共有したい場合は fig.supxlabel
or fig.subylabel
を利用します.
また,tight_layout
でうまくいかない場合は以下を使います.
plt.savefig(fig_name, bbox_inches="tight")
gridに対して対数Plotをする場合は以下のようにGridを定義します.
ax.grid(which="minor", color="gray", linestyle=":")
ax.grid(which="major", color="black")