1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

自分のデータセットの特徴マップを可視化する

Posted at

目標

特徴マップを可視化する

実装

import timm
import os
import matplotlib.pyplot as plt
import torch

def main():
    os.makedirs("result", exist_ok=True)

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # datasetは特徴マップを見てみたいデータセットを使う
    dataset = YourDataset()
    loader = torch.utils.data.DataLoader(dataset, batch_size = 1, shuffle = False, pin_memory=True, num_workers = 1)

    print("train size: ", len(dataset))
    
    # netは特徴マップを見てみたいモデルを使う
    net = YouModel()
    net.load_state_dict(torch.load("your_pretrained_model_path.pth"), strict = True)
    # timmを使っても良い
    # net = timm.create_model(model_name, in_chans = 3, features_only=True, pretrained=True, out_indices = (0, 1, 2, 3, 4))
    net = net.to(DEVICE, non_blocking=True)
    net.eval()

    with torch.no_grad():
        for count, item in enumerate(loader):
            # 素のRGB画像がpre_transform_imageに,pytorchのtensorがimage_tensorに格納されていると仮定
            # 必要であれば,ココを書き換えてもらえればOK
            pre_transform_image, image_tensor = item
            processed = []

            # 素のRGB画像を格納する
            # pytorchのtensorの場合は以下のコメントアウトのようにすれば良いはず
            # processed.append(image_tensor.squeeze(0).permute(1, 2, 0).contiguous().numpy())
            processed.append(pre_transform_image.squeeze(0))

            # 出力した特徴マップは複数あると想定する
            # timmのout_indices = (0, 1, 2, 3, 4)をするように複数の特徴マップを可視化したい場合は有効
            image_tensor = image_tensor.to(DEVICE, non_blocking=True)
            outputs = net(image_tensor)

            for feature_map in outputs:
               # チャンネル方向の平均値を特徴マップとする
               # 必要であれば,チャンネルごとの特徴マップも出力できるが,チャンネル数が512とかあると大変なので,妥協した
               feature_map = feature_map.squeeze(0)
               gray_scale = torch.sum(feature_map, 0)
               gray_scale = gray_scale / feature_map.shape[0]
               print(gray_scale.shape)
               processed.append(gray_scale.cpu().numpy())

            # figsize = ()で画像サイズを適宜変えること
            # figsize = (height, width)であることに注意
            fig = plt.figure(figsize = (320, 180))
            for count_, fm in enumerate(processed):
                a = fig.add_subplot(1, len(processed) + 1, count_ + 1)
                imgplot = plt.imshow(fm)
                a.axis("off")
            
            # 直接確認したければ,下のコードのコメントアウトを外すこと
            # plt.show()
            plt.savefig(f"result/{count}.png")
            plt.clf()
            plt.close()
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?