0
0

More than 3 years have passed since last update.

Mean ShiftをJavaScriptで実装した

Posted at

はじめに

色々な機械学習処理をブラウザ上で試せるサイトを作った」中で実装したモデルの解説の四回目です。

今回はMeanShiftの実装について解説します。ほぼ解説はしていませんが。

デモはこちらから。(TaskをClusteringにして、ModelのMean Shiftを選択)
実際のコードはmean_shift.jsにあります。

なお、可視化部分については一切触れません。

概説

実装はこのサイトを参考にしました。
この数式をだいたいそのままコードに書き起こしています。

なので具体的な流れはそちらを参照してもらい、実装で異なっている部分のみ説明します。

数式との相違点

二点、異なる部分があります。

  1. 更新式

    上記サイトでは

    \vec{\nabla} f_K \left(\vec{x}\right) = \frac{1}{h^{2}} f_G \left(\vec{x}\right) \vec{s}_G \left(\vec{x}\right)
    

    となっていますが、実装では

    \vec{\nabla} f_K \left(\vec{x}\right) = \vec{s}_G \left(\vec{x}\right)
    

    としています。
    fd中のreturn文のコメントアウトしてある方が数式と一致する更新式です。

    なお、なぜこのような変更を行ったのかは、記憶のかなたです。とりあえず動いているようには見えます。

  2. 体積の計算

    二次元(円の面積)に固定されています。
    ただし体積は$f_G \left(\vec{x}\right)$の計算にのみ使用するため、動作には影響はありません。

    正しい更新式を用いる場合は、超球の体積計算式を使用するように変更する必要があります。

    超球の体積はWikipediaより、

    V_n (r) = \frac{n^{\pi / 2}}{\Gamma \left(\frac{n}{2} + 1\right)} r^n
    

    であり、またガンマ関数$\Gamma(x)$は同じくWikipediaより、自然数$n$について同じく

    \begin{eqnarray}
    \Gamma(n + 1) &=& n! \\
    \Gamma(n + \frac{1}{2}) &=& \frac{(2n - 1)!!}{2^n}\sqrt{\pi}
    \end{eqnarray}
    

    が成り立ちます。
    ただし、数式をそのまま実装すると桁数が大変なことになるので、全体を対数でくくったほうが良いです。

    超球の体積計算は、K近傍法に基づく密度推定手法に実装したものがあります。
    この中のilogvが超球の体積の逆数の対数を取ったものになります。つまり$\log{\frac{1}{V_n(r)}}$です。

    _logGamma(z) {
        let x = 0
        if (Number.isInteger(z)) {
            for (let i = 2; i < z; i++) {
                x += Math.log(i)
            }
        } else {
            const n = z - 0.5
            x = Math.log(Math.sqrt(Math.PI)) - Math.log(2) * n
            for (let i = 2 * n - 1; i > 0; i -= 2) {
                x += Math.log(i)
            }
        }
        return x
    }
    
    predict(data) {
        const ps = this._near_points(data)
        const r = ps[ps.length - 1].d
        const d = data.length
        const ilogv = this._logGamma(d / 2 + 1) - d / 2 * Math.log(Math.PI) - d * Math.log(r)
        return Math.exp(ilogv) * this.k / this._p.length
    }
    

推論

推論は、自身を中心とした超球の位置(this._centroids)によって決定します。

同一のクラスタとする条件を、中心が厳密に一致する、とするとうまくいかない場合が多いので、一定の誤差を許容するようにします。
これをthis._thresholdより設定し、この距離よりも近い場合は同一のクラスタとするようにしました。

コード

class MeanShift {
    // see http://seiya-kumada.blogspot.com/2013/05/mean-shift.html
    // see http://takashiijiri.com/study/ImgProc/MeanShift.htm
    constructor(h, threshold) {
        this._x = null;
        this._centroids = null;
        this._h = h;
        this._threshold = threshold;
        this._categories = 0;
    }

    get categories() {
        return this._categories;
    }

    get h() {
        return this._h;
    }

    set h(value) {
        this._h = value;
    }

    set threshold(value) {
        this._threshold = value;
    }

    _distance(a, b) {
        return Math.sqrt(a.reduce((s, v, i) => s + (v - b[i]) ** 2, 0))
    }

    init(data) {
        this._x = data;
        this._centroids = this._x.map(v => [].concat(v));
    }

    predict() {
        this._categories = 0;
        const p = []
        for (let i = 0; i < this._centroids.length; i++) {
            let category = i;
            for (let k = 0; k < i; k++) {
                if (this._distance(this._centroids[i], this._centroids[k]) < this._threshold) {
                    category = p[k];
                    break;
                }
            }
            if (category === i) this._categories++;
            p[i] = category;
        }
        return p;
    }

    fit() {
        if (this._centroids.length === 0 || this._x.length === 0) {
            return;
        }
        const d = this._centroids[0].length;
        const Vd = Math.PI * (this._h ** 2);
        const G = (x, x1) => x.reduce((acc, v, i) => acc + ((v - x1[i]) / this._h) ** 2, 0) <= 1 ? 1 : 0;
        const mg = (gvalues) => {
            let s = 0;
            let v = Array(this._x[0].length).fill(0);
            this._x.forEach((p, i) => {
                if (gvalues[i]) {
                    s += gvalues[i];
                    v = v.map((a, j) => a + p[j] * gvalues[i])
                }
            });
            return v.map((a, i) => a / s);
        };
        const sg = (x, gvalues) => mg(gvalues).map((v, i) => v - x[i]);
        const fg = (gvalues) => {
            return gvalues.reduce((acc, v) => acc + v, 0) / (gvalues.length * Vd);
        }
        const fd = (x) => {
            let gvalues = this._x.map(p => G(x, p));
            return sg(x, gvalues);
            //return sg(x, gvalues).mult(2 / (this._h ** 2) * fg(gvalues));
        };
        let isChanged = false;
        this._centroids = this._centroids.map((c, i) => {
            let oldPoint = c;
            const v = fd(c);
            const newPoint = c.map((a, i) => a + v[i])
            isChanged |= oldPoint.some((v, i) => v !== newPoint[i]);
            return newPoint;
        });

        return isChanged;
    }
}

さいごに

スコープの近い場所に同じ変数名を使う(最後の方のiv)、という愚行を犯す。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0