はじめに
One-Class SVMを用いて異常検知の実験をしてみました。手法の概要とPythonでの実装を載せます。
One-Class SVM
One-Class SVMとは
One-Class SVM(One-Class Support Vector Machine:OCSVM)は、SVMを領域推定問題に応用した手法です。
通常のSVMでは2つのクラスに属するデータから、そのクラスデータを分離する超平面を学習するのに対して、OCSVMでは主に1つのクラスデータのみから、そのクラスデータが分布する領域を学習します。
学習した正常データの分布の外にあるデータを異常と判定することで、異常検知に応用できます。
OCSVMの手順
- データ$\mathbf{x}$を非線形関数$\mathbf{\Phi}$で特徴空間に写像する
- 原点からの距離(マージン)を最大化するように分離超平面$f(\mathbf{x})$を求める
- $f(\mathbf{x})$の値によって、データ点$\mathbf{x}$が特徴空間で分離超平面のどちら(内側か外側)に分布しているかを決定する
数式
$d$次元の$n$個のデータ
\mathbf{x}_{1},\mathbf{x}_{2},\ldots,\mathbf{x}_{n}\in\mathcal{X}\subset\mathbb{R}^{d}
が与えられたとします。$F$を内積空間とし、非線形写像$\mathbf{\Phi}:\mathcal{X}\to F$を考えます。この非線形写像$\mathbf{\Phi}$を用いることで、特徴空間における内積をカーネル関数を使って計算できます。
k(\mathbf{x},\mathbf{y})=\mathbf{\Phi}(\mathbf{x})'\mathbf{\Phi}(\mathbf{y})
データセットを原点から分離するために、OCSVMは以下の最適化問題を解きます。
\min_{\mathbf{w}\in F,~\mathbf{\xi}\in\mathbb{R}^{n},~\rho\in\mathbb{R}} \quad \frac{1}{2}||\mathbf{w}||^{2}+\frac{1}{\nu n}\sum_{i=1}^{n}\xi_{i}-\rho
\mbox{subject to} \quad \mathbf{w}'\mathbf{\Phi}(\mathbf{x}_{i})\geq \rho-\xi_{i}, \quad \xi_{i}\geq 0 \qquad (i=1,\ldots,n)
各項の意味
- $\mathbf{w}:~$決定境界の特徴空間での法線ベクトル
- $\mathbf{\Phi}(\cdot):~$非線形写像
- $\rho:~$境界のオフセット(マージン)
- $\xi_{i}:~$スラック変数(領域外でも正常とみなす許容度)
- $\nu\in(0,1]:~$異常許容量(最大で割合$\nu$のデータが境界外にいてもOK)
この最適化問題は:
- できるだけ原点から離れた位置に超平面を置く($\rho$を大きくする)
- ただし全データがその内側に収まらなくてもよく、少し外れてもよい($\xi_{i}$を許容)
- モデルが複雑になりすぎないように($||\mathbf{w}||^{2}$を最小化)
という3つのバランスを取っています。
実際にはカーネル関数を使って変形した以下の双対問題を解きます。
\min_{\alpha} \quad \frac{1}{2}\sum_{i,j=1}^{n}\alpha_{i}\alpha_{j}k(\mathbf{x}_{i},\mathbf{x}_{j})
\mbox{subject to} \quad 0\leq\alpha_{i}\leq\frac{1}{\nu n}, \quad \sum_{i=1}^{n}\alpha_{i}=1
- $\alpha_{i}:~$ラグランジュ乗数
- $k(\cdot,\cdot)=\mathbf{\Phi}(\cdot)'\mathbf{\Phi}(\cdot):~$カーネル関数
訓練後のデータ点$\mathbf{x}$に対して、決定関数を以下の式で計算します。
f(\mathbf{x})=\mbox{sgn}\left(\sum_{i=1}^{n}\alpha_{i}k(\mathbf{x}_{i},\mathbf{x})-\rho\right)
- $f(\mathbf{x})=1~$ならば正常
- $f(\mathbf{x})=-1~$ならば異常
と判断します。
Pythonでの実装
動作環境
Python : 3.9.12
numpy : 1.21.5
matplotlib : 3.5.1
sklearn : 1.0.2
Pythonコード
ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import OneClassSVM
仮想データの生成
# トレインデータ(正常)
X_train = 0.3 * np.random.randn(100, 2)
X_train = np.r_[X_train + 2, X_train - 2]
# テストデータ(正常 + 異常混在)
X_test = 0.3 * np.random.randn(20, 2)
X_test = np.r_[X_test + 2, X_test - 2]
X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2))
トレインデータは、2次元正規分布から100個の乱数を生成し、100個それぞれのデータに対して$(2,2)$と$(-2,-2)$に平行移動させて結合しています($d=2,~n=200$)。$(2,2)$と$(-2,-2)$付近にクラスターをもつデータになります。
テストデータは、2次元正規分布からの20個の乱数を$(2,2)$と$(-2,-2)$に平行移動させた正常データ($d=2,~n=40$)と区間$[-4,4]$上の一様分布から発生させた異常データ($d=2,~n=20$)なっています。
OCSVMのモデル学習
clf = OneClassSVM(kernel="rbf", gamma=0.1, nu=0.05) # nu=異常許容量
clf.fit(X_train)
カーネル関数はガウスカーネル($\gamma=0.1$)、異常許容量$\nu=0.05$として、トレインデータで学習しています。
テストデータの予測
y_pred_test = clf.predict(X_test)
y_pred_outliers = clf.predict(X_outliers)
学習したモデルを用いて、テストデータ(正常$+$異常)の予測を行っています。
データと分離曲線の可視化
最後にデータ点と学習した分離曲線をプロットします。
plt.figure(figsize=(8, 6))
xx, yy = np.meshgrid(np.linspace(-5, 5, 500), np.linspace(-5, 5, 500))
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.Blues_r)
plt.contour(xx, yy, Z, levels=[0], linewidths=2, colors='red')
plt.scatter(X_train[:, 0], X_train[:, 1], c='white', edgecolors='k', label="Training data")
plt.scatter(X_test[:, 0], X_test[:, 1], c='blue', edgecolors='k', label="Test data")
plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c='red', edgecolors='k', label="Outliers")
plt.legend()
plt.title("One-Class SVM for Anomaly Detection")
plt.show()
- 白点:トレインデータ
- 青点:テストデータ(正常)
- 赤点:テストデータ(異常)
- 赤のライン:OCSVMで学習した分離曲線
になります。概ね白点(トレインデータ)と青点(正常なテストデータ)は分離曲線の内側に入り、赤点(異常なテストデータ)は分離曲線の外側に分布しており、異常検知モデルとして機能していることがわかります。
おわりに
One-Class SVMの手法の概要とPythonでの実装についてまとめました。
今後の課題としては
- ハイパーパラメータ($\nu,~\gamma$)のチューニング
- 適切なカーネル関数の選択
- 他の異常検知手法との比較
などに取り組んでみたいです。まだカーネル法の理解が浅いので、カーネル法の理論的な勉強もしていきたいです。
参考文献
- [1] Estimating the Support of a High-Dimensional Distribution
https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-99-87.pdf - [2] One-Class Support Vector Machine OCSVM
https://datachemeng.com/wp-content/uploads/oneclasssupportvectormachine.pdf