CNNの特徴マップが各層でどんな出力になっているかを可視化してみました。
入力画像にはこちらの画像を使います。
コード全体
pytorchを用いてCNNを実装します。
ここでは、efficientnet_b0の学習済みモデルを使用します。
フォワードフックという機能を使うことで、モデルの中間出力を取り出すことができます。
import torch
import timm
import matplotlib.pyplot as plt
import cv2
import numpy as np
# 1. モデルを取得
model = timm.create_model('efficientnet_b0', pretrained=True)
model.eval() # 評価モードに設定
# 2. 特徴マップを保存するリストを準備
feature_maps = []
# 3. フォワードフックを作成
def hook_fn(module, input, output):
feature_maps.append(output)
# 4. 各レイヤーにフックを登録
for name, layer in model.named_modules():
if isinstance(layer, torch.nn.Conv2d): # Conv2dレイヤーの出力を可視化
layer.register_forward_hook(hook_fn)
# 5. 入力画像を準備
image = cv2.imread("flower.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (224, 224))
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).unsqueeze(0).float()
# 6. フォワードパスで特徴マップを取得
with torch.no_grad():
_ = model(image)
# 7. 特徴マップを可視化
def visualize_feature_maps(feature_maps, layer_idx=0):
fmap = feature_maps[layer_idx].squeeze(0) # バッチ次元を削除
num_filters = fmap.shape[0] # フィルタの数を取得
plt.figure(figsize=(20, 20))
for i in range(min(num_filters, 36)): # 最初の36フィルタを可視化
plt.subplot(6, 6, i+1)
plt.imshow(fmap[i].cpu().numpy(), cmap='viridis')
plt.axis('off')
plt.show()
# 8. 最初のConv層の特徴マップを可視化
visualize_feature_maps(feature_maps, layer_idx=0)
結果
ご覧の通り、出力層に近いところの方がより抽象的な特徴マップを得られています。
直感的には、微細な特徴が重要となるタスクでは入力層に近い特徴マップをうまく使うことがモデルの精度に影響してくるのかなぁ、と想像してます。