LoginSignup
3
4

More than 1 year has passed since last update.

お題は不問!Qiita Engineer Festa 2023で記事投稿!

【PyTorch】実装有:VAEを使った継続学習異常検知手法:Continual Learning for Anomaly Detection with Variational Autoencoder

Posted at

はじめに

継続学習(CL;Continual Learning)とは、動的に学習データが変化する環境下において、破壊的忘却することなくモデルの学習を可能とすることを目的とした機械学習の一分野となります.
※破壊的忘却とは、単一のネットワークを複数のタスクの学習に利用する場合、過去に学んだタスクに対する精度が、新しいタスクの学習時に悪化する事象のことを指します

ICLRやICMLをはじめとしたAI関連のトップカンファレンスにおいても、CLに関する論文の投稿数は増加傾向にあり、注目されている分野といえます.
主にCLは識別モデルの学習という問題設定で議論されることが多いですが、生成モデルをもちいる「異常検知」問題においても適用されます.
今回はCLの問題設定の上で、異常検知に関する手法を提案した論文Continual Learning for Anomaly Detection with Variational Autoencoderを紹介します.

[1]F. Wiewel, B. Yang, Continual learning for anomaly detection with variational autoencoder, in: ICASSP 2019, pp. 3837–3841

非公式ですが私の方で実装をしました.こちらをご参考に頂ければ幸いです.

論文解説

概要

異常検知の問題設定においても、破壊的忘却の問題は発生し得るが、継続学習の活用により、これを解決することができます.異常検知では変分オートエンコーダ(VAE;Variational Autoencoder)が良く用いられますが、その生成機能は異常検知にいては無視されます.→再構成を活用するため
※異常検知に関する基本的な内容についてはこちらをご参照ください.

Fig. 1 異常検知で用いられるVAEのイメージ.正常データのみの1クラスデータセットを用意し、入出力間の誤差(再構成誤差)を最小化するようにモデルを学習させる.異常データの入力時には再構成誤差が大きくなる仕組みを利用して異常検知する.

この論文では、まず異常検知を継続学習の問題設定にアレンジメントしたうえで、VAEの生成機能を活用したDeep Generative Replayライクな効率的な継続学習手法を提案しています.
継続学習の手法という観点では、極めてシンプルでありながらも有効かつ効率的であることが評価できるほか、
継続学習×異常検知という議題をもちこんだ点も評価できるといえます.

問題設定

D1,D2,..Di.,DNからなる正常データを用意します.各データセットDiは画像データ{x1,x2,...xN}からなります.
iはi番目のタスクを意味します.
目標はこのデータセットが継続的に与えられるときに、すべてのタスクに対してVAEを学習させることです.
重要な制約は、(継続学習の問題設定として一般的ですが)タスクがモデルに対して逐次的に与えられますが、過去学習したタスクは二度と参照できません.

MNISTで具体的に例をお示しします.1~9のラベルがついたデータセットをそれぞれタスク1~9として分割します.
最初にデータセットD1を学習し、次はD2を学習しますが、D2の学習時には1のラベルがついたデータD1は参照できません.

提案手法

image.png

Fig. 2 [1]の手法の図

上記の問題設定でVAEを学習すると、過去学習したタスクの生成は破壊的忘却によりできなくなります.
そこで一つ前のタスク(i-1)を学習したDecoderにいくつかサンプルを生成させてそれを今のタスク(i)の学習に利用します.異常検知において無視されるVAEの生成モデルとしての利点を効率的に活用しているといえます.
生成したサンプルを活用する手法は既出で、[2] Deep Generative Replay(DGR)にて提案されたものがベースです.

[2] H. Shin et al. “Continual learning with deep generative replay.” In Advances in Neural Information Processing Systems, pp. 2990-2999. (2017).

[2]の手法はGANベースで識別モデルのほか生成器が必要となりますが、本手法ではVAEだけでDGRができます.

検証

MNISTなどの公開データを継続学習の問題設定に落とし込みます.
まずデータセットをD0,D1,D2,..Di.,D9に分割し、D0を異常クラスとして定義し、それ以外を正常として定義します.
正常クラスは各タスクが完了すると同時にその定義を拡張します.
例:タスク2完了時→正常クラスは1と2
タスク5完了時→正常クラスは1から5

学習時:
タスクiにおいてモデルに与えられるのはDiのみです.

テスト時:
異常クラス:0
正常クラス:1~i
として、各タスクが終了したときにAUROCを用いた異常検知の精度評価を実施します.

比較対象として、
Upper Bound:過去のタスクも参照可能として学習したケース
GR:過去のタスクは参照不可、今回の手法を使ったケース
EWC:過去のタスクは参照不可、破壊的忘却を抑止する継続学習の一手法(従来手法)
Lower Bound:過去のタスクは参照不可、なんの対策もしないケース

image.png

結果として、本手法の有効性が確認できました.タスクが進んでもAUROCがUpperに張り付いていることがわかります.
一方でEWCやLowerについてはタスクが進むにつれて破壊的忘却を起こしてしまい、異常検知能力が下がっていることがわかります.

実装

コード解説

実装については私が空気を読んで補完した部分がある点ご容赦ください.
一般のVAEを用いた異常検知と大差ないのですが特筆する部分はDecoderの生成データをmini-batchに組み込んで学習させた部分と思います.
シナリオ(タスクの順番ことです)が0以外のときには、一つ前のタスクのdecoderを呼び出しています。外部にweightを保存しております。

main.oy
if s != 0:
        p_decoder = torch.load(f'./save/weight/decoder_scenario{int(s - 1)}.pth')
        prev_decoder.load_state_dict(p_decoder)
        prev_decoder.eval()

    for epoch in range(opt.num_epochs):
        i = 0
        for img,_ in train_loader:
            if s != 0:
                img = torch.cat([img.to(device).view(img.size(0), -1),prev_decoder(torch.randn(int(opt.batch_size/2),latent_dim).to(device))],dim=0)

            if opt.flatten:
                x = img.view(img.size(0), -1)
            else:
                x = img

            if cuda:
                x = Variable(x).cuda()
            else:
                x = Variable(x)

            xhat, mean, log_var = model(x)
            loss = loss_function(x, xhat, mean, log_var)

            losses[epoch] = losses[epoch] * (i / (i + 1.)) + loss * (1. / (i + 1.))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

極めて単純ですが有効性はあるのでよい手法だと思います.
詳細についてはこちらのGithubを参照願います。

検証結果

image.png
DGRが機能していることがわかります.

image.png

Upperでは当然破壊的忘却を起こさないので、正常データに対する再構成が正しくできていることがわかります.

image.png

DGRにより、過去のタスクも生成できています.Upperと比較すると、生成データを学習している影響からか、ぼやけが発生しており、再構成品質は劣る印象です.

image.png
過去学習したタスクのforgettingが発生しています.現在学習した6しか再構成ができません.

image.png

過去学習したタスクのサンプル生成に成功していますが、Decoder自体も若干の忘却を起こしているためか、過去タスクの生成品質が低い印象です.

おわりに

今回は異常検知に関する手法を提案した論文Continual Learning for Anomaly Detection with Variational Autoencoderを紹介しました.
論文実装し追試したところ本手法の有効性を確認しました.
一方で本手法自体も過去タスクに対する忘却が若干あるように見受けられるので、こうした部分の改善余地はあると思います.再掲:非公式ですが私の方で実装をしました.こちらをご参考に頂ければ幸いです.
最後までご覧いただきありがとうございました.

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4