目標
特徴マップを可視化する
実装
import timm
import os
import matplotlib.pyplot as plt
import torch
def main():
os.makedirs("result", exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# datasetは特徴マップを見てみたいデータセットを使う
dataset = YourDataset()
loader = torch.utils.data.DataLoader(dataset, batch_size = 1, shuffle = False, pin_memory=True, num_workers = 1)
print("train size: ", len(dataset))
# netは特徴マップを見てみたいモデルを使う
net = YouModel()
net.load_state_dict(torch.load("your_pretrained_model_path.pth"), strict = True)
# timmを使っても良い
# net = timm.create_model(model_name, in_chans = 3, features_only=True, pretrained=True, out_indices = (0, 1, 2, 3, 4))
net = net.to(DEVICE, non_blocking=True)
net.eval()
with torch.no_grad():
for count, item in enumerate(loader):
# 素のRGB画像がpre_transform_imageに,pytorchのtensorがimage_tensorに格納されていると仮定
# 必要であれば,ココを書き換えてもらえればOK
pre_transform_image, image_tensor = item
processed = []
# 素のRGB画像を格納する
# pytorchのtensorの場合は以下のコメントアウトのようにすれば良いはず
# processed.append(image_tensor.squeeze(0).permute(1, 2, 0).contiguous().numpy())
processed.append(pre_transform_image.squeeze(0))
# 出力した特徴マップは複数あると想定する
# timmのout_indices = (0, 1, 2, 3, 4)をするように複数の特徴マップを可視化したい場合は有効
image_tensor = image_tensor.to(DEVICE, non_blocking=True)
outputs = net(image_tensor)
for feature_map in outputs:
# チャンネル方向の平均値を特徴マップとする
# 必要であれば,チャンネルごとの特徴マップも出力できるが,チャンネル数が512とかあると大変なので,妥協した
feature_map = feature_map.squeeze(0)
gray_scale = torch.sum(feature_map, 0)
gray_scale = gray_scale / feature_map.shape[0]
print(gray_scale.shape)
processed.append(gray_scale.cpu().numpy())
# figsize = ()で画像サイズを適宜変えること
# figsize = (height, width)であることに注意
fig = plt.figure(figsize = (320, 180))
for count_, fm in enumerate(processed):
a = fig.add_subplot(1, len(processed) + 1, count_ + 1)
imgplot = plt.imshow(fm)
a.axis("off")
# 直接確認したければ,下のコードのコメントアウトを外すこと
# plt.show()
plt.savefig(f"result/{count}.png")
plt.clf()
plt.close()