1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ソーセージドックとソーセージを見分けてみた。そんで、AIの判断根拠を可視化してみた!

Posted at

ソーセージドックとソーセージをYOLOv8sで見分けてみた。

各クラス、15枚くらい。mAP(精度)は、0.91と良好。学習時間は20分ぐらい。AIの判断根拠を可視化するプログラム(EiganCAM)を使って、AIの判断根拠を学習回数毎に示し(EiganCAM解析)、動画にしてみた。

EigenCAMの解説

EigenCAMは、畳み込みニューラルネットワーク(CNN)に基づくモデルが入力画像のどの部分を重視しているかを視覚化する手法です。以下がEigenCAMの特徴です:

  • 目的:モデルの予測時に注目した画像の領域を特定し、視覚的に理解しやすい「ヒートマップ」として出力します。
  • 仕組み:モデルの特定の層(通常は後半の畳み込み層)を通してヒートマップを生成し、画像上に重ね合わせることで注目領域を示します。
  • 用途:画像分類や物体検出などの結果を、モデルがどの部分に着目して判断しているのか視覚的に確認するのに役立ちます。

これにより、モデルの判断の根拠を確認しやすくなり、AIの「ブラックボックス」部分の理解が進む点が特徴です。

学習

実行は、Colab。Google driveと連携した。学習、EiganCAM解析どちらもGPU必須。
学習のコード。epoch毎のモデルを出力する。

学習
!pip install ultralytics==8.2.103 -q
from IPython import display
display.clear_output()
import ultralytics
ultralytics.checks()

#dataset フォルダーに移動
!mkdir -p /yourdrive_dataset
%cd yourdrive_dataset

!pip install roboflow==1.1.48 --quiet
from roboflow import Roboflow
rf = Roboflow(api_key="-your roboflow api-")
project = rf.workspace("sample-rhjfd").project("sausagedog")
version = project.version(2)
dataset = version.download("yolov8")

#save_periodでモデルの重みを出力し、保存
!yolo task=detect mode=train model=yolov8s.pt data=yourdataset/data.yaml epochs=250 imgsz=800 plots=True save_period=1 save=True

上のコードにあるRoboflow API取得方法はこちら
mAP(精度)は、0.91と良好。学習時間は20分ぐらい。

image.png

SausageをAI君はSausagedogと認識。mAPの割には、意外と間違えている。
image.png

EiganCAM解析を実行

ディレクトリにあるモデルを一括して、EiganCAM解析を行う。

解析
!pip install ultralytics==8.2.103 torch opencv-python matplotlib requests pillow numpy torchvision ttach
!git clone https://github.com/rigvedrs/YOLO-V8-CAM.git
%cd /content/YOLO-V8-CAM
from google.colab import drive
drive.mount('/content/drive')
import torch
from matplotlib import pyplot as plt
from ultralytics import YOLO  # YOLOv8ライブラリ
from yolo_cam.eigen_cam import EigenCAM
from yolo_cam.utils.image import preprocess_image, show_cam_on_image
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import cv2
import numpy as np
import requests
import torchvision.transforms as transforms
from PIL import Image
import io
import os
import re
from datetime import datetime

def ensure_directory(directory):
    """ディレクトリが存在しない場合は作成する"""
    if not os.path.exists(directory):
        os.makedirs(directory)
        print(f"Created directory: {directory}")

def get_epoch_number(weight_file):
    """重みファイルからエポック番号を抽出する"""
    # epoch{数字}.pt または best_{数字}.pt の形式に対応
    match = re.search(r'(?:epoch|best_)(\d+)\.pt', weight_file)
    if match:
        return int(match.group(1))
    return None

def generate_eigencam_heatmap(model_path, image_path, output_dir):
    """EigenCAMを使用してヒートマップを生成する"""
    try:
        # ファイルの存在確認
        if not os.path.exists(model_path):
            print(f"Weight file not found: {model_path}")
            return None

        # 出力ディレクトリの作成
        ensure_directory(output_dir)

        # 入力画像の読み込みと前処理
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (640, 640))
        img_normalized = img.astype(np.float32) / 255.0

        # モデルのロード
        model = YOLO(model_path)

        # EigenCAMの設定
        target_layers = [model.model.model[-4]]
        cam = EigenCAM(model, target_layers, task='od')

        # ヒートマップの生成
        grayscale_cam = cam(img_normalized)[0, :, :]

        # サイズの確認とリサイズ
        if grayscale_cam.shape[:2] != (640, 640):
            grayscale_cam = cv2.resize(grayscale_cam, (640, 640))

        # ヒートマップの可視化
        cam_image = show_cam_on_image(img_normalized, grayscale_cam, use_rgb=True)

        # エポック番号の取得とファイル名の生成
# エポック番号の取得とファイル名の生成
        epoch_num = get_epoch_number(os.path.basename(model_path))
        if epoch_num is None:
            output_filename = f'heatmap_unknown_epoch.png'
        else:
            output_filename = f'heatmap_epoch_{epoch_num}.png'

        output_path = os.path.join(output_dir, output_filename)
        # OpenCVを使用して直接画像を保存
        cam_image_bgr = cv2.cvtColor((cam_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        cv2.imwrite(output_path, cam_image_bgr)

        print(f"Generated heatmap saved to: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error generating heatmap for {model_path}: {str(e)}")
        return None

def process_all_weights(weights_dir, image_path, output_dir):
    """すべての重みファイルに対してヒートマップを生成する"""
    # 重みファイルの一覧を取得してソート
    weight_files = [f for f in os.listdir(weights_dir) if f.endswith('.pt')]
    weight_files.sort(key=lambda x: get_epoch_number(x) or float('inf'))

    successful_maps = []
    failed_maps = []

    for weight_file in weight_files:
        weight_path = os.path.join(weights_dir, weight_file)
        result = generate_eigencam_heatmap(weight_path, image_path, output_dir)

        if result is not None:
            successful_maps.append(weight_file)
        else:
            failed_maps.append(weight_file)
            print(f"Skipping to next weight file...")
            continue

    # 結果の表示
    print("\nProcessing complete!")
    print(f"Successfully processed {len(successful_maps)} weight files")
    print(f"Failed to process {len(failed_maps)} weight files")
    if failed_maps:
        print("\nFailed weight files:")
        for f in failed_maps:
            print(f"- {f}")

# メイン処理
if __name__ == "__main__":
    weights_dir = 'モデルがあるディレクトリ'
    image_path = 'EiganCAM解析したい画像のパス'
    output_dir = '出力したいディレクトリ名を入力'

    process_all_weights(weights_dir, image_path, output_dir)

Pythonコードの説明

このPythonコードは、YOLOv8とEigenCAMを用いて一連の重みファイル(モデルの異なるエポックバージョン)からヒートマップを生成するスクリプトです。コードの各関数の役割は以下の通りです:

  1. ensure_directory(directory)

    • 指定したディレクトリが存在しない場合に新規作成します。
  2. get_epoch_number(weight_file)

    • 重みファイル名からエポック番号("epoch"や"best"を含む部分)を抽出します。
  3. generate_eigencam_heatmap(model_path, image_path, output_dir)

    • 指定された重みファイルと画像を用いて、EigenCAMでヒートマップを生成し、出力ディレクトリに保存します。
    • モデルと画像の前処理、EigenCAMの適用、ヒートマップのリサイズと保存までを行います。
  4. process_all_weights(weights_dir, image_path, output_dir)

    • 指定したディレクトリ内の全ての重みファイルに対してgenerate_eigencam_heatmapを適用し、成功したものと失敗したものをリスト化して結果を表示します。

動画にする

EiganCAM解析により出力されたヒートマップらを動画に作成する(fps10)

動画
import cv2
import os
import re
import numpy as np

def create_video_from_heatmaps(heatmap_dir, output_video_path, fps=2):
    """ヒートマップ画像からビデオを作成する"""
    try:
        # ヒートマップ画像のファイル一覧を取得
        image_files = [f for f in os.listdir(heatmap_dir) if f.startswith('heatmap_epoch_') and f.endswith('.png')]

        # エポック番号でソート
        def get_epoch_num(filename):
            match = re.search(r'heatmap_epoch_(\d+)\.png', filename)
            return int(match.group(1)) if match else float('inf')

        image_files.sort(key=get_epoch_num)

        if not image_files:
            print("No heatmap images found!")
            return False

        # 最初の画像を読み込んでビデオの設定を取得
        first_image = cv2.imread(os.path.join(heatmap_dir, image_files[0]))
        height, width, layers = first_image.shape

        # VideoWriterの設定
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        # 各画像をビデオに追加
        for image_file in image_files:
            image_path = os.path.join(heatmap_dir, image_file)
            frame = cv2.imread(image_path)
            video.write(frame)
            print(f"Added frame: {image_file}")

        # ビデオを閉じる
        video.release()
        print(f"\nVideo created successfully: {output_video_path}")
        return True

    except Exception as e:
        print(f"Error creating video: {str(e)}")
        return False

# メイン処理
if __name__ == "__main__":
    # 入力ディレクトリと出力ファイルのパスを設定
    heatmap_dir = 'EiganCAM解析によるヒートマップが存在するディレクトリ名'
    output_video = '出力したい動画名とそのパス'

    # ビデオを作成。ここでは、fpsを10に設定した。
    create_video_from_heatmaps(heatmap_dir, output_video, fps=10)

<結果>

仮説

勾配の最小化(損失関数の最適化)は2次元上ではなく、3次元(以上?)で行われているため、epochが増える度に、判断根拠が段々と明確に表れるわけではない。
⇄学習が進むにつれて、損失関数が減少することと矛盾。この仮説は棄却。

実装でわかったこと

・Claude Opusが肌感覚では強い。ただ、高い。少し遅い。
・チャットを更新していくたびに、AIが「何をやりたいのか」メモリを更新していくのが感じた。
・関数って、奥が深い!
・XAI(AIの判断根拠を教えるプログラム)はまだまだ、これから。
・XAIを解析するAIがあっても、面白いと思った(maybe maybe coming soon!)

実装で注意したこと

確実に、epoch毎のヒートマップの順で動画になるようにしたこと。

参考にしたプログラム

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?