Abstract
matplotlibって便利ですよね。私も様々な場面で活用しています。
本記事では不連続なグラフの図を作るためのtipsを紹介しようと思います。(意外とこの知見がネットにまとまって落ちていなかったので…)
問題設定
本記事では次のグラフを作成します。
y = \mathrm{sign}(\sin(x))
ただし$\mathrm{sign}(x)$は符号関数であり、次のように定義されます。
\mathrm{sign}(x) = \begin{cases}
1 & \text{if $x > 0$}\\
0 & \text{if $x = 0$}\\
-1 & \text{if $x < 0$}
\end{cases}
まず普通にプロットすると次のようになります。
※以降もそうですが、薄い線は$\sin(x)$です。
ソースコードはこちら
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-np.pi, np.pi, 1000)
y = np.sign(np.sin(x))
plt.plot(x, y)
plt.plot(x, np.sin(x), alpha=0.2)
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
labels=[r"$- \pi$", r"$- \frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"])
plt.yticks([-1, 0, 1])
plt.show()
matplotlibは当然ですが、不連続なんて認識しません。
認識して勝手に線を消す方が厄介まであります。
解法1
matplotlibはnp.inf, np.nan
といった値を無視して描画しません。それを利用したのが次の図とコードです。
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-np.pi, np.pi, 1000)
dx = ( x[-1] - x[0] ) / 1000
y = np.sign(np.sin(x))
# 不連続な点付近を`np.nan`で上書きすることで、グラフから消す。
y[np.gradient(y, dx) > 1] = np.nan
plt.plot(x, y)
plt.plot(x, np.sin(x), alpha=0.2)
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
labels=[r"$- \pi$", r"$- \frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"])
plt.yticks([-1, 0, 1])
plt.show()
今回の関数は$\sin(x)$が0になるタイミングで不連続になるので、$\sin(x) = 0$となるような点を描画から除外します。解析的には$x = -\pi, 0, \pi$ですが、この値がプログラム中のx
に含まれているとは限りません。なのでここでは勾配が一定以上になったら除外といった形で対応しました。
そこまで悪くはないですが、やはり含まない点は◯で明示したいです。
なお$y = 1 / x$といった無限大に発散するパターンでは同様の手法で対応できます。
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-1, 1, 1000)
y = 1 / x
# 不連続な点付近を`np.nan`で上書きすることで、グラフから消す。
threshold = 100
y[y < -threshold] = np.nan
y[y > threshold] = np.nan
plt.plot(x, y)
plt.show()
ここでは適当な値を閾値として、その値より大きい・小さいものnp.nan
で上書きしています。
もちろん、勾配から判定してもよいです。
解法2
今回の関数は次のように明確に書くことができます。
\mathrm{sign}(\sin(x)) = \begin{cases}
1 & \text{if $0 < x < \pi$}\\
0 & \text{if $x = -\pi, 0, \pi$}\\
-1 & \text{if $-\pi < x < 0$}
\end{cases}
なのでこれを丁寧に描画します。
そしてfacecolor(s)
というオプションで◯を再現します。
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-np.pi, np.pi, 1000)
LINE_COLOR = "blue"
FACE_COLOR = plt.rcParams['axes.facecolor'] # 現在のAxesの背景色を取得
exclude_point_options = dict(
facecolors=FACE_COLOR, # これで点の塗りつぶしが背景色になります
edgecolors=LINE_COLOR, # 点の線の色です
s=25, # 調整してください
zorder=2 # 他の線と被ったときに上側に描画してもらうためです
)
include_point_options = dict(
facecolors=LINE_COLOR,
edgecolors=LINE_COLOR, # 点の線の色です
s=exclude_point_options["s"],
zorder=exclude_point_options["zorder"]
)
for k in range(2):
segment_length = np.pi
x_seg = np.linspace(x[0] + k*segment_length, x[0] + (k+1)*segment_length, 1000)
y_seg = np.ones_like(x_seg) * (-1)**(k+1)
plt.plot(x_seg, y_seg, color=LINE_COLOR, zorder=exclude_point_options["zorder"] - 1) # 端っこの点が線の上になるようにzorderを調整
plt.scatter([x_seg[0], x_seg[-1]], [y_seg[0], y_seg[-1]], **exclude_point_options)
zeros = np.array([- np.pi, 0, np.pi])
plt.scatter(zeros, np.zeros_like(zeros), **include_point_options)
plt.plot(x, np.sin(x), alpha=0.2)
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
labels=[r"$- \pi$", r"$- \frac{\pi}{2}$", r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"])
plt.yticks([-1, 0, 1])
plt.show()
重要になるのは
- 色の固定
-
zorder
を設定 -
facecolors, edgecolors
を設定
になります。
色の固定をしておかないと、自動で色が割り振られてしまい区間ごとに色が変わってしまいます。
またzorder
を点が大きくなるように設定しないと、線が被ってしまいます。
facecolors, edgecolors
を設定することで、意図した色に設定しています。
なおfacecolors = "none"
とすると透明になります。透明にしてしまうと線が被ってしまうので今回は背景の色に合わせるという形で調整しています。
参考