LoginSignup
1
0

More than 3 years have passed since last update.

GMMをJavaScriptで実装した

Posted at

はじめに

色々な機械学習処理をブラウザ上で試せるサイトを作った」中で実装したモデルの解説の九回目です。

今回はGMMの実装について解説します。

デモはこちらから。(TaskをClustering/Anomaly Detection/Density Estimationにして、ModelのGaussian mixture modelを選択)
実際のコードはgmm.jsにあります。

なお、可視化部分については一切触れません。

概説

GMM(Gaussian Mixture Model)はクラスタリング、密度推定、異常検知に使われるモデルです。

複数の正規分布をいい感じに変形させて、データの密度を推定することがGMMの目的です。

理論およびアルゴリズムはこちらのスライドに詳しく記述されています。
実装は、このP.39にまとめられているアルゴリズムをコード起こしただけです。

以下、数式やノーテーションは上記サイトを参考にしています。

クラスタリング

クラスタリングで使用する場合は、各データがどの確率分布に属する可能性が最も高いかによってクラスタリングを行います。

\newcommand{\argmax}{\mathop{\rm arg~max}\limits}

\argmax_{k} \mathcal N \left(\bf x | \mu_{\it k}, \Sigma_{\it k} \right)

密度推定

密度推定は、確率分布の重み付き和によって次の通り計算ができます。

p(\bf x) = \sum_{{\it k} = 1}^{\it K} \pi_{\it k} \mathcal N \left(\bf x | \mu_{\it k}, \Sigma_{\it k} \right)

異常検知

クラス内にコードは存在しませんが、正常である確率は以下の通り計算しています。

\begin{eqnarray}
p'(\bf x) &=& 1 - \prod_{{\it k} = 1}^{\it K} \exp\left(-\pi_{\it k} \mathcal N \left(\bf x | \mu_{\it k}, \Sigma_{\it k} \right) \right) \\
&=& 1 - \exp(-p(\bf x))
\end{eqnarray}

この確率が一定以下の場合に異常と判定することで、異常検知ができます。

分散の計算

分散は、分布の平均に近いデータ点が一つとなると、正則ではなくなるようです。
すると確率が計算できなくなり、以降の処理に問題が発生します。
なので更新時、分散の対角成分に微少量を足すようにしています。

コード

GMM

class GMM {
    // see https://www.slideshare.net/TakayukiYagi1/em-66114496
    // Anomaly detection https://towardsdatascience.com/understanding-anomaly-detection-in-python-using-gaussian-mixture-model-e26e5d06094b
    //                   A Survey of Outlier Detection Methodologies. (2004)
    constructor(d) {
        this._k = 0;
        this._d = d;
        this._p = [];
        this._m = [];
        this._s = [];
    }

    add() {
        this._k++;
        this._p.push(Math.random());
        this._m.push(Matrix.random(this._d, 1));
        const s = Matrix.randn(this._d, this._d);
        this._s.push(s.tDot(s));
    }

    clear() {
        this._k = 0;
        this._p = [];
        this._m = [];
        this._s = [];
    }

    probability(data) {
        return data.map(v => {
            const x = new Matrix(this._d, 1, v);
            const prob = [];
            for (let i = 0; i < this._k; i++) {
                const v = this._gaussian(x, this._m[i], this._s[i]) * this._p[i];
                prob.push(v);
            }
            return prob;
        })
    }

    predict(data) {
        return data.map(v => {
            const x = new Matrix(this._d, 1, v);
            let max_p = 0;
            let max_c = -1;
            for (let i = 0; i < this._k; i++) {
                let v = this._gaussian(x, this._m[i], this._s[i]);
                if (v > max_p) {
                    max_p = v;
                    max_c = i;
                }
            }
            return max_c;
        });
    }

    _gaussian(x, m, s) {
        const xs = x.copySub(m);
        return Math.exp(-0.5 * xs.tDot(s.inv()).dot(xs).value[0]) / (Math.sqrt(2 * Math.PI) ** this._d * Math.sqrt(s.det()));
    }

    fit(datas) {
        const n = datas.length;
        const g = [];
        const N = Array(this._k).fill(0);
        const x = [];
        datas.forEach((data, i) => {
            const ns = [];
            let s = 0;
            const xi = new Matrix(this._d, 1, data);
            for (let j = 0; j < this._k; j++) {
                const v = this._gaussian(xi, this._m[j], this._s[j]) * this._p[j];
                ns.push(v || 0);
                s += v || 0;
            }
            const gi = ns.map(v => v / (s || 1.0));
            g.push(gi);
            x.push(xi);
            gi.forEach((v, j) => N[j] += v);
        });

        for(let i = 0; i < this._k; i++) {
            const new_mi = new Matrix(this._d, 1);
            for (let j = 0; j < n; j++) {
                new_mi.add(x[j].copyMult(g[j][i]));
            }
            new_mi.div(N[i]);
            this._m[i] = new_mi;

            const new_si = new Matrix(this._d, this._d);
            for (let j = 0; j < n; j++) {
                let tt = x[j].copySub(new_mi);
                tt = tt.dot(tt.t);
                tt.mult(g[j][i]);
                new_si.add(tt);
            }
            new_si.div(N[i]);
            new_si.add(Matrix.eye(this._d, this._d, 1.0e-8))
            this._s[i] = new_si;

            this._p[i] = N[i] / n;
        }
    }
}
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0