0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

おーい、いその〜 AutoEncoder作ろ〜

0
Posted at

AutoEncoderとは

教師なし学習の一つ

今回は異常検知モデルとして実装していく
また、ここでは詳しいことは書かず、体系的に実装していくので
お手柔らかに & ご了承ください

データセット

勉強用というのと
色々著作権とかで面倒なので自前で用意した

  • 猫ちゃん真顔Ver → OK(30枚) // 学習6: 検証2: 評価2
  • 猫ちゃん赤面Ver → NG( 6枚)

としよう
用意したデータ(?)に対して異論は認めない。
sample_1.png

ゴール

以下とする
saple_2.png

これを見ると、パターンマッチングとかでもいいじゃん
という声も出そうだが
例えば正常画像に"バリエーション"がある場合
AutoEncoderのような画像生成による異常検知も有効な場面があるのでは、と思う

結果

sample_3.png

sample_4.png

とりあえずこれで真顔猫ちゃんか否かを検出できるAEができた

全コード

from glob import glob

import piq
import numpy as np
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torchvision import transforms as T

sns.set_style('darkgrid')

# 活性化関数としてMishを採用
class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))

# AutoEncoder定義
class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        # エンコーダー
        self.encoder = nn.Sequential(
            # RGBの3次元 x 128角の画像をInput
            # 特徴量をCNNを通じて少しづつ畳み込んでいく
            # Conv -> BatchNorm(層を正規化し、学習の安定を図る) -> Mish(活性化関数で非線形に複雑な特徴を掴む)
            nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            Mish(),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            Mish(),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            Mish(),

            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            Mish(),
        )

        # デコーダー
        self.decoder = nn.Sequential(
            # CNNでたたみ込まれた特徴量を元画像の次元数になるまで少しづつ広げていく            
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            Mish(),

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            Mish(),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            Mish(),

            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            # Sigmoid関数で特徴量を0〜1の範囲にする
            nn.Sigmoid(),
        )

        self._init_weight()

    def _init_weight(self):
        # 重み初期化
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(mod.bias, 0)
            elif isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)

    def forward(self, x):
        # 順伝播
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# 早期終了オブジェクト
class EarlyStopping:
    def __init__(self, patience: int):
        self.patience = patience
        self.history = []
        self.model_box = []

    def step(self, model: torch.nn.Module, loss: float):
        self.history.append(loss)
        model_no = len(self.history)
        torch.save(model, f'./models/tmp/model_no_{model_no}.pth')    

    def is_stop(self):
        # 検証の損失がpatience分減少しなければ、最適な損失時のモデルを返す
        best_epoch = np.argmin(self.history)
        if (len(self.history) - best_epoch) >= self.patience:
            best_model = torch.load(f'./models/tmp/model_no_{best_epoch}.pth', weights_only=False)
            return True, best_model
        return False, None


# データセットオブジェクト定義
class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_paths: list):
        self.transform = self._def_transform()
        self.img_paths = img_paths

    def __len__(self):
        return len(self.img_paths)

    def _def_transform(self):
        return T.Compose([
            T.ToTensor(),
        ])

    def __getitem__(self, idx):
        # 画像読み込み
        img = Image.open(self.img_paths[idx])
        # 画像をtorch型の0〜1の値に変換
        img = self.transform(img)
        return img


def train(
        model,
        optimizer,
        criterion,
        lr_scheduler,
        early_stopping,
        train_dataloader,
        val_dataloader,
    ):
    # 学習
    epochs = 300
    train_mean_losses = []
    val_mean_losses = []

    for epoch in range(epochs):
        model.train()

        # ミニバッチ学習
        losses = []
        for imgs in train_dataloader:
            # 勾配初期化(前回算出分の勾配初期化)
            optimizer.zero_grad()
            # バッチサイズ分の生成画像出力
            outputs = model(imgs)
            # 生成画像と元画像の差を損失関数で算出
            loss = criterion(outputs, imgs)
            # 誤差逆伝播
            loss.backward()
            # パラメータ更新
            optimizer.step()
            losses.append(loss.item())

        # 1Epoch分の損失を平均し、追加
        train_mean_losses.append(np.mean(losses))

        # 推論
        model.eval()
        losses = []
        # 推論モード(勾配を更新しない)に切り替え
        with torch.inference_mode():
            for imgs in val_dataloader:
                # 学習と同様、検証用画像を用いて画像生成し
                # 生成画像と元画像を比較、損失を算出する
                outputs = model(imgs)
                loss = criterion(outputs, imgs)
                losses.append(loss.item())
        val_mean_losses.append(np.mean(losses))

        print(f'''---------------------------------------
        Epoch     : {epoch + 1} / {epochs}
        Train loss: {train_mean_losses[epoch]}
        Val loss  : {val_mean_losses[epoch]}
        Diff      : {abs(train_mean_losses[epoch] - val_mean_losses[epoch])}
        Best score: epoch_{np.argmin(val_mean_losses)}, val loss_{np.min(val_mean_losses)}
        LR        : {lr_scheduler.get_last_lr()[0]}''')

        # 学習率スケジューラー更新
        lr_scheduler.step(np.mean(losses))

        # EarlyStopping更新, 学習継続判定
        early_stopping.step(model=model, loss=np.mean(losses))
        is_stop, best_model = early_stopping.is_stop()
        if is_stop:
            print(' - Early stopping - ')
            return best_model, train_mean_losses, val_mean_losses

    # 早期終了しなかった場合はベストモデルを最終エポック時のモデルとする
    if not is_stop:
        return model, train_mean_losses, val_mean_losses


def plot_learning_curve(train_mean_losses, val_mean_losses):
    # 学習曲線 可視化
    fig, ax = plt.subplots(dpi=150)

    x_labels = list(range(1, len(train_mean_losses)+1))

    ax.plot(x_labels, train_mean_losses, label='Train loss')
    ax.plot(x_labels, val_mean_losses, label='Val loss')

    ax.set_title('Epoch vs Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')

    ax.legend(loc='upper right')
    plt.show()


# 推論器
class Predictor:
    def __init__(self, model: nn.Module):
        self.model = model
        self.transform = T.Compose([
            T.ToTensor(),
        ])

    def prediction(self, img_path: str):
        # 画像読み込み
        img = Image.open(img_path)
        img = self.transform(img).unsqueeze(0)

        # 推論
        self.model.eval()
        with torch.inference_mode():
            pred_img = self.model(img)

        criterion = piq.SSIMLoss()
        loss = criterion(pred_img, img).item()

        # 0〜255の色情報に変換
        pred_img = pred_img.numpy() * 255.0
        # float -> int
        pred_img = pred_img.astype('uint8').squeeze()
        # PILで読み込めるように、[3, 128, 128] -> [128, 128, 3]へ
        pred_img = pred_img.transpose(1, 2, 0)
        return Image.fromarray(pred_img), loss


# 比較器
class Comparator:
    def __init__(self, predictor: Predictor):
        self.predictor = predictor

    def comparison(self, img_path: str):
        origin_img = Image.open(img_path)
        pred_img, loss = self.predictor.prediction(img_path=img_path)
        diff_img = Image.fromarray(np.array(origin_img) - np.array(pred_img))

        disp_img = Image.new('RGB', (128, 384))
        disp_img.paste(origin_img, (0, 0))
        disp_img.paste(pred_img, (0, 128))
        disp_img.paste(diff_img, (0, 256))

        return disp_img, loss

# 推論
def prediction(img_paths, comparator):
    losses = []
    disp_img_box = []
    for img_path in img_paths:
        disp_img, loss = comparator.comparison(img_path=img_path)
        losses.append(loss)
        disp_img_box.append(disp_img)

    fig, ax_list = plt.subplots(dpi=150, ncols=5)
    [ax.axis('off') for ax in ax_list]

    for ax, disp_img in zip(ax_list, disp_img_box):
        ax.imshow(disp_img)
    plt.show()

    return losses


# 損失差分可視化
def plot_diff(ok_losses, ng_losses):
    fig, ax = plt.subplots(dpi=120)

    ax.set_title('OK vs NG')
    ax.set_xlabel('Loss')
    ax.set_ylabel('Frequency')

    ax.hist(ok_losses, label='OK', range=[0.1, 0.3], bins=20, alpha=0.6)
    ax.hist(ng_losses, label='NG', range=[0.1, 0.3], bins=20, alpha=0.6)
    ax.vlines(x=0.165, ymin=0, ymax=2, colors='tomato', label='Threshold')
    ax.legend(loc='upper right')
    plt.show()


def main():
    # モデルオブジェクト生成
    model = AutoEncoder()
    # 損失関数, 今回はSSIMを採用, 単なるピクセル差(MSE)よりも、輝度/コントラスト/構造の比較を実施することで人間の視覚に近い評価を行う指標
    criterion = piq.SSIMLoss()
    # 最適化関数, 今回はRAdamを採用, 学習初期段階のパラメータ更新を安定させるべく導入
    optimizer = torch.optim.RAdam(model.parameters(), lr=0.001, weight_decay=0.001)
    # 学習率スケジューラー, patience分だけepochが進んでも損失が下がらなければ、LR *= gamma とする
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20)
    # 早期終了オブジェクト, patience分だけepochが進んでも損失が下がらなければ、学習切り上げ、最適なepoch時のモデルを返す
    early_stopping = EarlyStopping(patience=50)

    # 画像パスリスト取得
    ok_imgs = glob('./datasets/OK/*')
    ng_imgs = glob('./datasets/NG/*')
    # 学習データは正常画像のみ
    # 学習6 : 検証2 : 評価2 の割合とする
    train_imgs, val_test_imgs = train_test_split(ok_imgs, test_size=0.4)
    val_imgs, test_imgs = train_test_split(val_test_imgs, test_size=0.5)
    print('Datasets - ')
    print('Train', len(train_imgs), 'Val', len(val_imgs), 'Test', len(test_imgs))

    # データローダーオブジェクト定義
    train_dataloader = torch.utils.data.DataLoader(
        Dataset(train_imgs),
        batch_size=2,
        shuffle=True,
        drop_last=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        Dataset(val_imgs),
        batch_size=2,
        shuffle=True,
        drop_last=True,
    )

    # 学習
    best_model, train_mean_losses, val_mean_losses = train(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        lr_scheduler=lr_scheduler,
        early_stopping=early_stopping,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )

    # 学習曲線可視化
    plot_learning_curve(train_mean_losses=train_mean_losses, val_mean_losses=val_mean_losses)

    predictor = Predictor(model=best_model)
    comparator = Comparator(predictor=predictor)

    ok_losses = prediction(img_paths=test_imgs, comparator=comparator)
    ng_losses = prediction(img_paths=ng_imgs, comparator=comparator)

    # 損失差分可視化
    plot_diff(ok_losses=ok_losses, ng_losses=ng_losses)

if __name__ == '__main__':
    main()

これから

生成AIがハイスペックすぎて、カスタムでモデルを作ることも
かなり減ってきたと思う。

とはいえ、自分が興味のある分野や技術があれば
「なぜ」を追求してアウトプットしていくつもりだ
(できるだけツールベースではなく、研究者チックでもなく、"丁度良い本質"の理解を)

と言うわけでGANによる異常検知も後日記事にしようと思う(多分)

それではまた👋

参考文献

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?