概要
個人的な備忘録を兼ねたPyTorchの基本的な解説とまとめです。第4回画像分類の続きで4.5回目となります1。
第4回 はMNISTに類似したグレースケールの「あいうえお」画像によって画像分類の初歩を行いました2。今回は完全な続きでCNNの出力やカーネルの可視化からスタートします。
演習用のファイル(Githubからダウンロード)3
1. 第4回目の要約
ひらがな「あ・い・う・え・お」の5種類を手書き文字画像を利用した画像分類の演習を行いました。ネットワーク構造は下図のような畳み込み層と全結合層を利用した単純なものでした。
今回の目的は特徴量を画像を表示することです。
今回の目的
- 畳み込み層の出力結果を画像として表示
- 畳み込み層のカーネルを画像として表示
2. 画像表示コードと解説
2.1 画像データの表示
最初に読み込んだ画像データの表示から行います。matplotlibで画像を表示させるimshow()は、(高さ、幅)もしくは(高さ,幅,チャンネル数)の形になっている必要があります。読み込んだ画像データのチャンネル数の軸を削除する方法を採用します。
import numpy as np # データの読み込みに利用
import matplotlib.pyplot as plt # 画像表示
data = np.load("./aiueo.npz") # aiueo.npzの読み込み
x = data["x"] # "x"キーで画像データにアクセス (792,1,28,28)
img = x[0] # チャンネルの部分を削除して行列の形に変換
plt.imshow(img, cmap="gray_r") # 画像表示, gray_r: 白黒反転
plt.show()
imshow()のカラーオプションをcmap="gray_r"
としてあるので白黒が反転した32ピクセル×32ピクセルの画像が表示されます.
通常のグレースケールの画像cmap="gray"
だとこんな感じになります。
2.2 CNNの出力結果の表示
forwardの戻り値として準備してある、予測値(y)・CNNの出力値(h1)・活性化関数後の出力値(h2)の3種類を活用します。
コード内のprediction, feature_map, act_feature_map = model(sample_image)
が予測値、CNN出力値、活性化関数後の出力値となります。CNNの出力値はマイナスの値になることがありますが、imshow()によって適切な範囲に正規化して画像を表示してくれます。
import matplotlib.pyplot as plt
import japanize_matplotlib
# テスト用の一つの画像を選択
# あ:300, い:500 う:200 え:600 お:100
sample_image = x[200].unsqueeze(0) # バッチ次元を追加 (1, 1, 28, 28)
# モデルを評価モードに
model.eval()
# 推論実行
with torch.no_grad():
prediction, feature_map, act_feature_map = model(sample_image)
# 特徴マップのサイズは (1, 1, 24, 24) のはず
# バッチとチャネル次元を削除して2D画像として表示
feature_map = feature_map.squeeze().cpu().numpy() # (24, 24)の2D配列に変換
#---------------- 画像表示 ------------------
# out_channels=5なので5枚画像が出力される
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
plt.suptitle("CNN出力 (特徴マップ)")
for i, ax in enumerate(axes.flatten()):
ax.imshow(feature_map[i], cmap="gray")
ax.axis('off')
plt.tight_layout()
plt.show()
CNN出力画像やカーネルの画像について若干注意があります。演習用のファイルを実行しても記事と同一の画像にならない可能性があります。解釈などは適宜変更してください
CNNの出力画像
画像の白い部分が注目している部分となります。輪郭の検出、右上がりの線、上下の線、左右の線などに注目するなどの効果がなんとなく見てわかる?ということはなさそうです
LeakyReluのあとの画像
活性化関数を作用させたあとのCNN特徴量を表示させてみました。CNNの出力値がマイナス部分を中心に黒っぽくなるります。先程より少しだけ注目している部分がはっきり見えてくると思われます。上下の線、斜めの線、輪郭など注目している箇所が少しずつ異なっているようにも感じます。
fig, axes = plt.subplots(1, 5, figsize=(10,3))
plt.suptitle("act cnn出力")
for i, ax in enumerate(axes):
ax.imshow(act_feature_map[i], cmap="gray")
ax.axis('off')
plt.tight_layout()
plt.show()
2.3 カーネルの表示
CNNで学習されるカーネルについても画像として表示してみましょう。カーネルの各要素の値は、畳み込み計算時の対応する入力ピクセルへの重み付けを表します。正の大きな値(白っぽい色)はその位置の特徴を強調し、負の大きな値(黒っぽい色)はその位置の特徴を抑制します。カーネルは特定のパターンや特徴に反応するフィルターとして機能し、学習によってさまざまな視覚的特徴(エッジ、テクスチャ、形状など)を検出するものと解釈されています。
ネットワーク名cnn1の重みがカーネルなので、model.cnn1.weitht.detach()
でその重みを取り出すことができます。out_channels=5からカーネルの枚数は5枚となります。カーネルサイズが5x5を勘案すると、コード内の変数kernelsの形状は(5, 1, 5, 5)となります。
# カーネルの重みを取得
kernels = model.cnn1.weight.detach().cpu().numpy() # 形状は (5, 1, 5, 5)
fig, axes = plt.subplots(1, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
ax.imshow(kernels[i].squeeze(), cmap='gray')
ax.set_title(f'カーネル{i+1}枚目')
ax.axis('off')
plt.tight_layout()
plt.show()
- 1枚目は右下がりに反応するように見えます。
- 2枚目はちょっとバラバラでよくわからんぞ〜。
- 3枚目は右下がりには反応しない。それ以外に反応する形かな?
- 4枚目は上下に反応?右下がりにはあまり反応せずかな?
- 5枚目は左右の線に反応するよう見える。
2.4 画像表示について
matplotlibを利用して複数枚の画像を表示する方法はいくつかあります。今回の演習では、plot.subplots
で複数枚の画像配置位置を決めて、forループで順番に埋めていく形にしました。今回は1行5列なので効果が見られませんが、axes.flatten()
によって多次元の配置(例えば2行3列など)でも簡単にループできるようになっています。画像をタイル状に並べる方法はいくつもあるので好みの方法を利用してもらえると幸いです。
fig, axes = plt.subplots(1, 5, figsize=(10,3))
for i, ax in enumerate(axes.flatten()):
...
次回
画像分類のネットワークはCNN(Convolutional Neural Network)と一緒にプーリング層もよく使われます。第5回はプーリング層を加えた構造で画像分類問題を行い、特徴量を可視化してみたいと思います。
注
-
4回目が長くなったので2回に分けてしまいました
↩
-
独立行政法人産業技術総合研究所のETL文字データベースを利用させていただきました。「あ・い・う・え・お」の5種類の画像を抽出しMNISTと対応させるべく28x28サイズに縮小したものを利用しています。 ↩
-
演習用のファイルを実行しても記事と同一の画像表示になるとは限りません。ご注意ください。 ↩