
Estimating the basis for decisions in video classification using Grad-CAM

What I want to solve

I want to classify videos and display the basis for the classifier's decisions.
The classification model is ECO. Following Chapter 9 of 「作りながら学ぶ!Pytorch発展ディープラーニング」, I am trying to get binary classification and decision-basis estimation working on my own dataset.
Binary classification of the videos works fine, but I cannot get a Grad-CAM heatmap to display.
I have searched various papers and websites, but everything I find applies Grad-CAM to image classification, and I am stuck trying to adapt it to video classification.

Besides a direct solution, I would also be grateful for pointers to relevant papers or websites. Thank you in advance.
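For context, my understanding is that the Grad-CAM computation itself should generalize from 2D to video feature maps by averaging the gradients over the temporal dimension as well as the spatial ones. A minimal NumPy sketch of that math (the array shapes here are assumptions for illustration, not taken from ECO):

```python
import numpy as np

def video_grad_cam(activations, gradients):
    """Grad-CAM weighting generalized to a 3D-conv video feature map.

    activations: (C, T, H, W) feature maps from the last conv layer
    gradients:   (C, T, H, W) gradients of the target class score w.r.t. them
    returns:     (T, H, W) heatmap, one spatial map per time step, in [0, 1]
    """
    # Channel weights: global-average-pool the gradients over T, H, W
    weights = gradients.mean(axis=(1, 2, 3))  # (C,)

    # Weighted sum over channels, then ReLU
    cam = np.maximum(
        (weights[:, None, None, None] * activations).sum(axis=0), 0.0
    )  # (T, H, W)

    # Normalize to [0, 1] for visualization
    return cam / (cam.max() + 1e-8)
```

If this is the right generalization, the remaining question is how to hook it up to ECO's 3D-conv layers via pytorch_grad_cam (or by hand with backward hooks).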

Problem / error

TypeError: zip argument #1 must support iteration

Relevant source code

from dataloader2 import make_datapath_list, VideoTransform, get_label_id_dictionary, VideoDataset, CamDataset
from model import ECO_Lite

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import time
import os 
import copy
import cv2

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

batch_size = 3
weight_decay = 0.005
learning_rate = 0.001
num_epochs = 10

# Create the video lists
train_root_path = './data/left-handed/train/'
train_video_list = make_datapath_list(train_root_path)

val_root_path = './data/left-handed/val/'
val_video_list = make_datapath_list(val_root_path)

# Preprocessing settings
resize, crop_size = 224, 224
mean, std = [104, 117, 123], [1, 1, 1]
video_transform = VideoTransform(resize, crop_size, mean, std)

# Create the label dictionaries
label_dictionary_path = './data/pitching.csv'
label_id_dict, id_label_dict = get_label_id_dictionary(label_dictionary_path)

# Create the Datasets (images)
train_dataset = VideoDataset(train_video_list, label_id_dict, num_segments=16,
                             phase="train", transform=video_transform,
                             img_tmpl='image_{:05d}.jpg')

val_dataset = VideoDataset(val_video_list, label_id_dict, num_segments=16,
                             phase="val", transform=video_transform,
                             img_tmpl='image_{:05d}.jpg')

cam_dataset = CamDataset(val_video_list, label_id_dict, num_segments=16,
                             phase="val", transform=video_transform,
                             img_tmpl='image_{:05d}.jpg')

# Wrap the Datasets in DataLoaders
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle = True)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle = False)

cam_dataloader = torch.utils.data.DataLoader(
    cam_dataset, batch_size=batch_size, shuffle = False)

# Collect into a dict
dataloaders_dict = {"train":train_dataloader, "val":val_dataloader}
train_steps = len(train_dataloader.dataset) // batch_size
val_steps = len(val_dataloader.dataset) // batch_size
steps = {"train":train_steps, 'val':val_steps}

# Prepare the network, loss function, and optimizer
device = torch.device("cuda")
net = ECO_Lite()
net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

cam_iterator = iter(cam_dataloader)
imgs_transformeds, labels, label_ids, dir_path, indices, img_tmpl = next(cam_iterator)


def _load_imgs(dir_path, img_tmpl, indices):
    '''Load a set of images and return them as a list'''
    img_group = []  # list to store the images
    file_path_list = []

    for i in range(16):
        # Build the image path
        file_path = os.path.join(dir_path[0], img_tmpl[0].format(indices[0, i].item()))
        file_path_list.append(file_path)

        # Load the image
        img = Image.open(file_path).convert('RGB')

        # Append to the list
        img_group.append(img)
    return img_group, file_path_list

img_group, file_path_list = _load_imgs(dir_path, img_tmpl, indices)

pitching_img = []
for i in range(len(file_path_list)):
    pitching_img.append(cv2.imread(file_path_list[i]))
    pitching_img[i] = cv2.cvtColor(pitching_img[i], cv2.COLOR_BGR2RGB)

target_layers = [net]
cam = GradCAM(model = net, target_layers = target_layers, use_cuda = torch.cuda.is_available())

input_tensor = imgs_transformeds #(batch_size, channel, height, width)
vis_image = cv2.resize(pitching_img[0], (256, 256)) / 255.0 #(height, width, channel), [0, 1]
label = [ClassifierOutputTarget(0)]

print(input_tensor.shape, vis_image.shape)

grayscale_cam = cam(input_tensor = input_tensor, targets = label_ids)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(vis_image, grayscale_cam, use_rgb = True)
plt.imshow(visualization)
plt.show()
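In case it clarifies what I am after: as far as I can tell from the pytorch_grad_cam docs, show_cam_on_image expects the image as float RGB in [0, 1] with the same height and width as the cam. The overlay step itself can also be sketched without the library (a simple alpha blend with a red heatmap rather than the library's colormap; shapes are assumptions):

```python
import numpy as np

def overlay_cam(image, cam, alpha=0.5):
    """Blend a Grad-CAM heatmap onto an image.

    image: (H, W, 3) float RGB in [0, 1]
    cam:   (H, W) float in [0, 1], same H and W as the image
    """
    # Put the cam intensities into the red channel only (simplified colormap)
    heatmap = np.zeros_like(image)
    heatmap[..., 0] = cam

    # Alpha-blend and clamp back into [0, 1]
    return np.clip((1 - alpha) * image + alpha * heatmap, 0.0, 1.0)
```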

I realize this may be hard to follow, but I would greatly appreciate any help.
