はじめに
jupyter 上で画像認識なんてやっていると、たくさんの画像やグラフを並べたくなる事がある。
MNIST の教師画像を全部表示するとか。(本当にやると帰ってこない上に落ちると思う)
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import keras
mnist = keras.datasets.mnist.load_data()
(X_train, t_train), (X_test, t_test) = mnist
f, ax = plt.subplots(3000,20, figsize=(20,3000))
axes = ax.reshape((-1,))
for x, ax in zip(X_train, axes):
ax.imshow(x, cmap=cm.gray)
ax.axis('equal')
ax.axis('off')
plt.show()
でも、グラフや画像を描くたびに matplotlib をちまちま呼び出すのが面倒くさくなった。
そこで、勝手に表示する仕組みを作ってみた。
matplotlib.pyplot.subplots()
画像やグラフを並べる場合、subplots() を使う。
import matplotlib.pyplot as plt
nrows = 3
ncols = 5
f, axes = plt.subplots(nrows, ncols, squeeze=False, figsize=(ncols*2.0, nrows*2.0))
for i in range(13):
r = i // ncols
c = i % ncols
axes[r,c].plot(range(10), range(r+c,10+r+c)) # 適当なグラフ
plt.show()
しかし、subplots() には、
- matplotlib.pyplot.show() を呼ぶまでは表示されない。数が多いと待たされる。
- グラフ・画像の枚数があらかじめ分かっていないといけない。プログラム中でグラフの数が変わったりすると都合が悪い。
- グラフ・画像の縦横のサイズを予め決めないといけない。
- 一度にあまりたくさん表示できない。
という欠点があり、使いやすいとは言えない。
特に、jupyter 上でお手軽に画像を眺めたい時に使いづらい。
1 行ごとに plt.show() すればいいんじゃない?
はじめに思いついたのは、1 行ごとに subplots() と show() を呼ぶ、という方法。
import matplotlib.pyplot as plt
ncols = 5
f, axes = plt.subplots(1, ncols, figsize=(ncols*2.0, 2.0))
for i in range(13):
axes[i % 5].plot(range(10), range(i,10+i)) # 適当なグラフ
if i % 5 == 4:
plt.show()
f, axes = plt.subplots(1, ncols, figsize=(ncols*2.0, 2.0))
plt.show()
画像を大量に表示する場合に、少しずつ表示されていくので、待たされ感がずいぶんと減る。
1行分のグラフ数と縦横サイズがあればいいので、レイアウトについて考えなくていい。
単純なループに組み込むのであれば、これでもいい。
しかし、axes をチェックする処理がわずらわしい。
ジェネレータがあればいいんじゃない?
面倒な処理を何度も書きたくないので、ジェネレータにしてみた。
import matplotlib.pyplot as plt
def axes_generator(ncols):
while True:
f, axes = plt.subplots(1, ncols, figsize=(ncols*2.0, 2.0))
for c in range(ncols):
yield axes[c]
plt.show()
ag = axes_generator(5)
for i, ax in zip(range(13), ag):
ax.plot(range(10), range(i,10+i)) # 適当なグラフ
plt.show()
ジェネレータのおかげで、グラフを描く処理はそれに専念できるようになった。
ループの中がスッキリして読みやすい。
グラフを追加するなら、__next__() を呼べばイテレータが次の Axes を返してくれる。
ag = axes_generator(5)
for i, ax in zip(range(13), ag):
ax.plot(range(10), range(i,10+i)) # 適当なグラフ
if i % 3 == 2:
ax = ag.__next__()
ax.bar(range(5), range(i,i+5)) # 適当なグラフ
plt.show()
でも、余った Axes が見えてしまって格好悪い。
ジェネレータクラスがあればいいんじゃない?
余った分は見えないようにしたいので、クラス化することにした。
- ジェネレータとして使える。
- subplots() で作った Axes をメンバで管理。
- グラフを追加できるように、Axes を取得するメソッドを用意。
- 最後の plt.show() の前に、余った Axes の軸を消して Axes を不可視にした。
import matplotlib.pyplot as plt
class AxesGenerator:
def __init__(self, ncols:int=6, figsize:tuple=None, *args, **kwargs):
self._ncols = ncols
self._figsize = figsize
self._axes = []
def __iter__(self):
while True:
yield self.get()
def get(self):
if len(self._axes) == 0:
plt.show()
f, axes = plt.subplots(nrows=1, ncols=self._ncols, figsize=self._figsize)
self._axes = list(axes) if self._ncols > 1 else [axes,]
ax = self._axes.pop(0)
return ax
def flush(self):
for ax in self._axes:
ax.axis('off')
plt.show()
self._axes = []
ncols = 5
ag = AxesGenerator(ncols, figsize=(ncols*2.0, 2.0))
for i, ax in zip(range(13), ag):
ax.plot(range(10), range(i,10+i)) # 適当なグラフ
if i % 3 == 2:
ax = ag.get()
ax.bar(range(5), range(i,i+5)) # 適当なグラフ
ag.flush()
かなりいい感じになってきた。
ただ、後片付けのため flush() を呼んでいるけれど、間違えて plt.show() にしたり、書き忘れたり、なんかやらかしそう。
後片付けが面倒くさいなら with 文に任せればいいんじゃない?
後片付けなら with 文が使えるはず。
with のため、__enter__() と __exit__() を追加した。
ついでに、コンストラクタで subplots() の引数を受け付けるようにした。
import matplotlib.pyplot as plt
class AxesGenerator:
def __init__(self, ncols:int=6, sharey=False, subplot_kw=None, gridspec_kw=None, **fig_kw):
self._ncols = ncols
self._sharey = sharey
self._subplot_kw = subplot_kw
self._gridspec_kw = gridspec_kw
self._fig_kw = fig_kw
self._axes = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.flush()
return True # 例外処理は省いてある
def __iter__(self):
while True:
yield self.get()
def get(self):
if len(self._axes) == 0:
plt.show()
f, axes = plt.subplots(nrows=1, ncols=self._ncols, sharey=self._sharey, subplot_kw=self._subplot_kw, gridspec_kw=self._gridspec_kw, **self._fig_kw)
self._axes = list(axes) if self._ncols > 1 else [axes,]
ax = self._axes.pop(0)
return ax
def flush(self):
for ax in self._axes:
ax.axis('off')
plt.show()
self._axes = []
ncols = 5
with AxesGenerator(ncols, figsize=(ncols*2.0, 2.0)) as ag:
for i, ax in zip(range(13), ag):
ax.plot(range(10), range(i,10+i)) # 適当なグラフ
if i % 3 == 2:
ax = ag.get()
ax.bar(range(5), range(i,i+5)) # 適当なグラフ
おわりに
たくさんのグラフや画像を手っ取り早く眺める、という目的は達成できた。
今なら、MNIST の教師画像 60,000 枚を全部表示なんてこともできる。(本当にやると途中で落ちると思う)
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import keras
mnist = keras.datasets.mnist.load_data()
(X_train, t_train), (X_test, t_test) = mnist
with AxesGenerator(ncols=20, figsize=(20,1)) as ag:
for x, ax in zip(X_train, ag):
ax.imshow(x, cmap=cm.gray)
ax.axis('equal')
ax.axis('off')
もう少しスマートなコードにできる気もするが、それはまたいつか。