はじめに
CVPR 2018から
[1] M. Sabokrou, et. al. "Adversarially Learned One-Class Classifier for Novelty Detection" のまとめ。
著者らの公式コード
https://github.com/khalooei/ALOCC-CVPR2018
私が作成したデモプログラム
https://github.com/masataka46/ALOCC
概要
- GANをベースにした異常検知のモデル
- 世界初?のend-to-endで学習する1クラス分類モデル
- 学習したモデルがそのまま異常検知の推論に使える
- 異常検知タスクで最高精度を達成
まず、上側のペンギンのデータで学習する。
推論時に正常データ(ペンギン)を入れると高いスコアとなり、異常データ(ペンギン以外)を入れると低いスコアとなって分別できる。
アーキテクチャ
アーキテクチャの全体像
アーキテクチャの全体像を以下の[1]Figure 2 で説明する。
まず R-Network(refineとreconstructionの意味を持つ)と D-Network がある。
$X$(例えばペンギン画像)にノイズを加えたものを R-Networkに入れ、auto-encoderのごとくencode、decodeし、$X'$ を出力する。
この $X'$ をD-Network(DiscriminatorとDetectの意味を持つ)に入れ、D-Networkとしてはこれはペンギンクラスでないと分類するよう学習する。
一方で R-NetworkとしてはこれをD-Networkにペンギンクラスと分類させるよう学習する。
D-Network にはペンギンの生画像 $X$ も入力する。これに対し D-Networkはペンギンと分類するよう学習する。
R-Network
R-Networkの詳細は以下の[1]Figure 3。
CNNを使ったencoder-decoderモデル。batch-normも使う。
活性化関数に関しては論文中で言及されてないがReLuやLeaky-ReLuだろう。poolingの代わりにconvのstrideで縮小する?
このR-Networkには正常なクラス(例えばペンギン)のデータにノイズを不可したものをひたすら入れるので、ペンギン的な特徴量を学び、refineされたペンギンを再構築すると考えられる。
D-Network
D-Network のアーキテクチャは以下の[1]Figure 4。
こちらも多層のCNN構造。最終的に正常クラス(例えばペンギン)の尤度を出力するよう学習する。
よって学習後は、変なペンギンやペンギン以外の画像が入力されると低い確率を出力する。これにより異常か正常かを判定できる。
学習の方法
R-Networkへの入力・出力
まず R-Network にはデータ分布 $p_t$ からサンプリングされた $X$ に正規分布に従うノイズ $\mu$ を加えたものを入力する。これを再構築&refineする。
\tilde{X} = (X \sim p_t) + (\mu \sim \mathcal{N} (0,\sigma^{2} \bf{I} \rm)) \rightarrow X' \sim p_t
ノイズを加える事で入力画像に含まれるノイズや歪みに対してロバストにする。
目的関数
adversarialなロスは以下のminmax式。
\mathcal{L}_{\mathcal{R} + \mathcal{D}}= \min_{\mathcal{R}} \max_{\mathcal{D}} \left( \mathbb{E}_{X \sim p_t} [\log \left( \mathcal{D} \left( X \right) \right) + \mathbb{E}_{\tilde{X} \sim p_t + \mathcal{N}_{\sigma}} [\log \left( 1 - \mathcal{D} \left( \mathcal{R} \left( \tilde{X} \right) \right) \right) ] \right)
この他に再構築ロスとして以下のL2 normも用いる。
\mathcal{L}_{\mathcal{R}} = \| X - X'\|^2
これら2つを足し合わせたロスの全体像は以下。
\mathcal{L} = \mathcal{L}_{\mathcal{R} + \mathcal{D}} + \lambda \mathcal{L}_{\mathcal{R}}
異常の判定方法
以上の学習を実行することにより、R-Networkは正常データ $X$ をよりrefineされたものに変換できるようになり、またD-Networkは正常データか否かを正確に判定できるようになるだろう。
まず、D-Networkのみを用いて異常か否かを判定する方法は以下。
OCC_{1}(X) =
\begin{cases}
Target \ Class & if \ \mathcal{D}(X) > \tau \\
Novelty \ (Outlier) & otherwise,
\end{cases}
D-Networkの出力である正常データに対する確率が閾値を超えたら正常と判定する。
これで十分 SOTA を達成できるが、折角 Refine するR-Networkができあがってるので、これを利用すると
OCC_{2}(X) =
\begin{cases}
Target \ Class & if \ \mathcal{D}(\mathcal{R}(X)) > \tau \\
Novelty \ (Outlier) & otherwise,
\end{cases}
となる。これだとさらに精度が上がる。
実験と結果
UCSD Ped2 dataset を用いた異常検知
UCSD Ped2 datasetを用いて正常データ(人の歩行画像)で学習させた後、正常データと異常データ(車等写ってる画像)をR-Networkに入力した時の再構築画像は以下。
左半分が正常データ、右半分が異常データ。
それぞれ2行目が元画像、1行目がそれにノイズを加えた入力画像、3行目が再構築画像。
正常画像はちゃんと再構築されてるが、異常画像は車や自転車などが再構築されてない。
MNISTを用いた異常検知
MNISTの1(正常データ)で学習させた後、6と7(異常データ)をR-Networkに入れた場合の再構築画像は以下。
6と7に対しても無理やり1っぽく出そうとしてる。
この時の正常データ、異常データに対するOCC scoreの分布は以下。
上側が $OCC_2$ 、下側が $OCC_1$ 。赤が正常データに対するスコア、黒が異常データに対するスコア。
$OCC_2$ の方が正常と異常とを区別できてる。
Caltech-256を用いた他の手法との比較
Caltech-256を用いた他の手法と比較した結果は以下。AUCとF値で比較している。
$OCC_2$ スコアを用いた本手法がおおむね一番いい。
作成したデモプラグラムとその学習結果
MNISTを用いて異常検知するプログラムはこちら。
https://github.com/masataka46/ALOCC
AUCの計算等、現在修正中です。
問題設定
MNISTの 1 を入力とし、この数字を再構築するように学習する。
その後、1 及び他の数字を入力し、1 に対してはちゃんと再構築され、他の数字に対しては再構築されていない事を確認する。
結果
200回くらい学習させた結果が以下。
右側が1に対する結果、左側が他の数字に対する結果。
それぞれ左から入力画像、再構築画像、その差分を可視化したもの。