3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

EfficientnetNetをファインチューニングしてMWAMのメンバー画像分類

Last updated at Posted at 2024-02-29

はじめに

Qiita初投稿です。
大学院では深層学習を用いた研究を行っています。研究で自然言語処理や画像生成などを扱っているうちに、何か面白いもの作れないかなと思いぱっと作ってみたモデルについて記事にまとめます。具体的には、画像認識モデルであるEfficientNetをファインチューニングして画像分類を行いました。

コードや記事の書き方でご指摘などあればよろしくお願いします。

MWAMとは?

MWAM(MAN WITH A MISSION)とは日本の5人組ロックバンドで、天才生物学者であるジミー・ヘンドリックス博士によるマッドサイエンスの結果生まれた究極の生命体です。
「オオカミのバンドね~」と言われたりするのを見かけますが違います。オオカミではありません。究極の生命体です。

メンバーは、トーキョー・タナカ(Vo.)、ジャン・ケン・ジョニー(Gt./Vo./Raps)、カミカゼ・ボーイ(Ba./Cho.)、DJ・サンタ・モニカ(DJ/Sampling)、スペア・リブ(Dr.)の5人。

画像は掲載しませんが、初めて見た人にとっては誰が誰だかわかりません。楽器を持っている写真であれば、楽器からメンバー名を判断することは容易そうです。しかし楽器に頼って判断するようではガウラーといえません。そこで顔の画像を入力することで誰なのかを判定する分類モデルを構築することにしました。この分類モデルをもとにしっかりと顔と名前を覚えましょう。

データセットの構築

画像分類モデルとして今回は、EfficientNetをファインチューニングします。そこでファインチューニングを行うためのデータセットの構築を行います。既存のデータセットが無いので、インターネット上に転がっているデータを収集します。
今回は1人につき30枚の画像を収集しました。全身が写っているものについては手作業でトリミングを行い、顔だけを切り取ります。
またメンバー以外の画像として、オオカミの画像を収集しました。今回構築する分類モデルの入力には、明らかにメンバーではない画像(車、食べ物など)の入力を考えていません。そのため負例(メンバーの画像を正例とした場合)としてオオカミの画像を使用しました。

image.png

ディレクトリ構成

作成したデータセットのディレクトリ構成は以下の通りです。
各メンバーのフォルダの中に30枚ずつ顔写真が入っています。合計180枚の画像データセットです。

data_root/
       ┝ dj/
       │  ┝ dj000.jpg
       │  ┝ dj001.jpg
       │   ︙
       │  └ dj029.jpg
       ┝ johny/
       ┝ kamikaze/      
       ┝ rib/
       ┝ tokyo/
       └ wolf/

実装

実装するにあたり、コードは以下の記事を非常に参考にさせていただきました。

Pytorch Lightningを使用したEfficientNetのファインチューニング

モジュールのインポート

まずは使用するモジュールをインポートします。
ただしpytorch_lightningのバージョンは1.6.4を使用しました。バージョンによって使い方が変わる厄介者です。

!pip install pytorch-lightning==1.6.4
import torch
import torch.nn as nn
import random
import numpy as np
import glob
import pytorch_lightning as pl
import timm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

seed値の固定

次にseed値の固定を行います。こちらの記事を参考にさせていただきました。

【seed、本当に固定できた?】PyTorchの再現性に関して公式資料を読む

# 各seedの固定
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def seed_worker():
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(42)
DataLoader(train_dataset,
           num_workers=n_workers,
           batch_size=batch_size, 
           shuffle=True, 
           worker_init_fn=seed_worker, 
           generator=g)

Datasetの定義

Datasetのclassを定義します。
__getitem__では、各メンバーの画像に対してラベル付けを行っています。具体的には、以下の表の通りです。

名前 ラベル
オオカミ 0
DJ・サンタ・モニカ 1
ジャン・ケン・ジョニー 2
カミカゼ・ボーイ 3
スペア・リブ 4
トーキョー・タナカ 5

ただしオオカミはメンバーではありません。

class MWAMDataset(Dataset):
    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform

    def __len__(self):
        return len(self.img_path)
    
    def __getitem__(self, index):
        img_path = self.img_path[index]
        img = Image.open(img_path)
        img = self.transform(img)

        # ラベル付け
        if "wolf" in img_path:
            label = 0
        elif "dj" in img_path:
            label = 1
        elif "johnny" in img_path:
            label = 2
        elif "kamikaze" in img_path:
            label = 3
        elif "rib" in img_path:
            label = 4
        elif "tokyo" in img_path:
            label = 5
        return img, label 

DataModuleの定義

DataModuleのclassを定義します。
member_splitでは、学習データが不均衡データにならないように、各メンバーを均等に振り分けています。
setupには学習時、検証時、テスト時に使うデータセットを設定します。

class MWAMDataModule(pl.LightningDataModule):
    def __init__(self, data_root_paths, img_size=224,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        self.data_root_paths = data_root_paths
        self.train_paths, self.val_paths, self.test_paths = self.member_split()
        self.n_workers = 20

        self.train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

        self.val_test_transforms = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

    # データの振り分け
    def member_split(self):
        # train/val/testの画像データパスが格納される
        train_data_paths = []
        val_data_paths = []
        test_data_paths = []

        # 各メンバーについてデータを分割
        for member_path in self.data_root_paths:
            data = [img_path for img_path in glob.glob(member_path + "/*")]
            train_data, test_data = train_test_split(data, train_size=0.8, random_state=42)
            train_data, val_data = train_test_split(train_data, train_size=0.8, random_state=42)

            train_data_paths.extend(train_data)
            val_data_paths.extend(val_data)
            test_data_paths.extend(test_data)
        return train_data_paths, val_data_paths, test_data_paths        

    def setup(self, stage=None):
        if (stage == "fit") or (stage is None):
            self.train_dataset = MWAMDataset(self.train_paths, self.train_transforms)
            self.val_dataset = MWAMDataset(self.val_paths, self.val_test_transforms)
        if (stage == "test") or (stage is None):
            self.test_dataset = MWAMDataset(self.test_paths, self.val_test_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                            num_workers=self.n_workers,
                            batch_size=4, 
                            shuffle=True, 
                            worker_init_fn=seed_worker(), 
                            generator=g)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                           num_workers=self.n_workers,
                           batch_size=4, 
                           shuffle=False, 
                           worker_init_fn=seed_worker(), 
                           generator=g)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                           num_workers=self.n_workers,
                           batch_size=4, 
                           shuffle=False, 
                           worker_init_fn=seed_worker(), 
                           generator=g)

DataModuleのインスタンスを生成します。
引数には各メンバーの画像へのパスが含まれたリストを指定します。

# data_root_path = ["./data_root/rib",
#                   "./data_root/kamikaze",
#                   "./data_root/johnny",
#                   "./data_root/wolf",
#                   "./data_root/tokyo",
#                   "./data_root/dj"]
datasets = MWAMDataModule(data_root_paths)

モデルの定義

EfficientNetのモデル定義を行います。
モデルはtimmモジュールを使用して事前学習済みモデルをダウンロードします。引数のmodel_nameにモデル名を指定することでダウンロードできます。

class EfficientNet(pl.LightningModule):
    def __init__(self, model_name, n_classes):
        super().__init__()
        self.save_hyperparameters()

        self.model = timm.create_model(model_name, pretrained=True, num_classes=n_classes)
        self.model.classifier = nn.Linear(self.model.classifier.in_features, n_classes)
        self.loss = nn.CrossEntropyLoss()
        self.lr = 1e-3
        self.epoch_cnt = 0
        

    def forward(self, img, label=None):
        pred = self.model(img)
        return pred
    
    def training_step(self, batch, batch_idx):
        img, label = batch
        outputs = self(img, label)

        # trsin loss
        train_loss = self.loss(outputs, label)

        # 正解率の計算
        outputs = F.softmax(outputs, dim=1)
        preds = np.argmax(outputs.cpu().detach().numpy(), axis=1)
        true_label = label.cpu().detach().numpy()
        train_accuracy = torch.tensor(accuracy_score(true_label, preds))
        return {"loss": train_loss, "train_accuracy": train_accuracy}
    
    def training_epoch_end(self, outputs):
        train_mean_loss = torch.stack([output["loss"] for output in outputs]).mean()
        train_accuracy = torch.stack([output["train_accuracy"] for output in outputs]).mean()
        self.log_dict({"train_mean_loss": train_mean_loss,
                       "train_accuracy": train_accuracy,
                       "epoch": self.epoch_cnt})
        self.epoch_cnt += 1
        print("train loss:", train_mean_loss)


    def validation_step(self, batch, batch_idx):
        img, label = batch
        outputs = self(img, label)

        # val loss
        val_loss = self.loss(outputs, label)
        self.log("val_loss", val_loss)

        # 正解率の計算
        outputs = F.softmax(outputs, dim=1)
        preds = np.argmax(outputs.cpu().detach().numpy(), axis=1)
        true_label = label.cpu().detach().numpy()
        val_accuracy = torch.tensor(accuracy_score(true_label, preds))
        return {"val_loss": val_loss, "val_accuracy": val_accuracy}
    
    def validation_epoch_end(self, outputs):
        val_mean_loss = torch.stack([output["val_loss"] for output in outputs]).mean()
        val_accuracy = torch.stack([output["val_accuracy"] for output in outputs]).mean()
        self.log_dict({"val_mean_loss": val_mean_loss,
                       "val_accuracy": val_accuracy,
                       "epoch": self.epoch_cnt})
        if self.epoch_cnt == 0:
            self.epoch_cnt += 1
        
        print("val loss:", val_mean_loss)
    
    def test_step(self, batch, batch_idx):
        img, label = batch
        outputs = self(img, label)
        outputs = F.softmax(outputs, dim=1)

        # 予測ラベルと正解ラベル
        preds = list(np.argmax(outputs.cpu().detach().numpy(), axis=1))
        true_label = list(label.cpu().detach().numpy())
        return {"true_label": true_label,
                "pred_label": preds}
    
    def test_epoch_end(self, outputs):
        # 混同行列の計算
        true_label = [x for output in outputs for x in output["true_label"]]
        pred_label = [x for output in outputs for x in output["pred_label"]]
        confusion_mat = confusion_matrix(true_label, pred_label)
        print(confusion_mat)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

モデルのインスタンスを生成します。
EfficientNetの事前学習済みモデルはefficientnet_b4を使用します。また今回はメンバー+オオカミの6値分類なので、n_classes=6とします。

model = EfficientNet(model_name="efficientnet_b4", n_classes=6)

学習の設定

pytorch-lightningでは、Trainerを用いて学習を行います。
引数のcallbacksにアーリーストッピングなどの設定を指定します。
今回は検証データを用いてアーリーストッピングを行うため、エポック数の最大max_epochsは1000とします。

early_stopping = EarlyStopping(
        min_delta=0.001,
        patience=5,
        verbose=False,
        monitor="val_mean_loss",
        mode="min")

trainer = pl.Trainer(accelerator="gpu",
                     devices=1, 
                     max_epochs=1000,
                     enable_checkpointing=True,
                     callbacks=[early_stopping])

学習開始

trainer.fit()で学習を行うことができます。以下の1行を実行するだけで学習が始まります。学習の実行の様子やモデルのパラメータなども自動で表示されてすごい。
モデルの保存は、save_checkpointを使用して保存します。

trainer.fit(model, datasets)
trainer.save_checkpoint("checkpoint_name.ckpt")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | EfficientNet     | 17.6 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
17.6 M    Trainable params
0         Non-trainable params
17.6 M    Total params
70.237    Total estimated model params size (MB)

テストデータでの評価

最後にテストデータを用いてモデルの評価を行います。
trainer.test()でテストデータに対してモデルの性能評価が始まります。引数のckpt_pathには保存されたモデルの重みのパスを指定してください。

test_results = trainer.test(datamodule=datasets, ckpt_path="checkpoint_name.ckpt")

評価結果を示します。
まずは混同行列です。ぱっと見良い感じに分類できてそうです。image.png

次に定量的に評価します。
評価指標としては正解率、マクロ平均適合率$P_{\mathrm{macro}}$、マクロ平均再現率$R_{\mathrm{macro}}$、マクロ平均F値$F_{\mathrm{macro}}$を使用します。また、マクロ平均適合率、マクロ平均再現率、マクロ平均F値は各評価指標をクラス数で割った値であり、以下の式で求められます。

\begin{align}
P_{\mathrm{macro}} &= \frac{1}{K}\sum_{i}^K P_i \\
R_{\mathrm{macro}} &= \frac{1}{K}\sum_{i}^K R_i \\
F_{\mathrm{macro}} &= \frac{1}{K}\sum_{i}^K F_i
\end{align}

ここで$P_i$、$R_i$、$F_i$はそれぞれラベル$i$についての適合率、再現率、F値を表します。また$K$は分類するクラス数であり、今回は$K=6$になります。
以下に結果を示します。

評価指標
正解率 0.778
マクロ平均適合率 0.824
マクロ平均再現率 0.778
マクロ平均F値 0.800

良い感じの性能です。MWAMのメンバー分類モデルを構築することができました。

おわりに

データセットの構築は手作業で行ったため荒い部分が多くありますが、分類精度がほどほどに高いモデルを構築できました。学習のデータ数を増やしたり背景の加工処理など改善の余地はありそうです。
また、入力として顔画像を使用することを想定しているため、体全体の写真が入力できないこともなんとか改善できたらよさそうです。

次はアプリケーション化などしてみたいなと思います。

参考記事

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?