LoginSignup
1
1

More than 3 years have passed since last update.

k-means法とそれに関連したモデルのJavaScriptによる実装

Last updated at Posted at 2020-10-11

はじめに

色々な機械学習処理をブラウザ上で試せるサイトを作った」中で実装したモデルの解説の一回目です。

k-means法に関連した以下のモデルの実装について解説します。

  • k-means
  • k-means++
  • k-medois
  • Neural Gas

デモはこちらから。(TaskをClusteringにして、ModelのK-MeansまたはNeural Gasを選択)
実際のコードはkmeans.js、Neural Gasはneural_gas.jsにあります。

数学的な話は分かりやすく説明する自信が無いため、ほとんど行いません。
また、可視化部分については一切触れません。

概説

ほとんどの処理は同一であり、モデルによって処理の変わる部分は
- 新規のセントロイドの追加(add
- セントロイドの位置の移動(move
の二つのみです。

なので、これらの処理をコンストラクタで注入するようにします。
デザインパターンでいうところの、Dependency Injectionです。

まずは注入される側の共通処理を実装するクラスを示し、次に、モデル別の処理を行うクラスを示します。

共通処理

KMeansModelで、全てのモデルで共通の処理を実装します。

コンストラクタでモデル別の処理を行うクラスのインスタンスを受け取ります。
fitで学習処理を、predictで推論処理を行います。共に、二次元配列を受け取ります。
学習処理fitはセントロイドの移動を一回だけ行うようにし、移動距離を返却することで終了判定ができるようにしています。

class KMeansModel {
    constructor(method = null) {
        this._centroids = [];
        this._method = method || new KMeans();
    }

    get centroids() {
        return this._centroids;
    }

    get size() {
        return this._centroids.length;
    }

    set method(m) {
        this._method = m;
    }

    _distance(a, b) {
        let v = 0
        for (let i = a.length - 1; i >= 0; i--) {
            v += (a[i] - b[i]) ** 2
        }
        return Math.sqrt(v)
    }

    add(datas) {
        const cpoint = this._method.add(this._centroids, datas);
        this._centroids.push(cpoint);
        return cpoint;
    }

    clear() {
        this._centroids = [];
    }

    predict(datas) {
        if (this._centroids.length === 0) {
            return;
        }
        return datas.map(value => {
            return argmin(this._centroids, v => this._distance(value, v));
        });
    }

    fit(datas) {
        if (this._centroids.length === 0 || datas.length === 0) {
            return 0;
        }
        const oldCentroids = this._centroids;
        this._centroids = this._method.move(this, this._centroids, datas);
        const d = oldCentroids.reduce((s, c, i) => s + this._distance(c, this._centroids[i]), 0);
        return d;
    }
}

method.addの引数には、現在のセントロイドの配列と、分類対象データの配列を渡します。
method.moveの引数には、自分自身と現在のセントロイドの配列、分類対象データの配列を渡します。なおセントロイドの配列は第一引数から取ることができるので、渡さない方がいいかもしれません。

以降は、コンストラクタの引数methodに渡されるインスタンスのクラスの解説です。

k-means

k-means法では、新規のセントロイドはランダムに選択します。ただし、完全に一致するデータを取ると困るので、既存のセントロイドに近すぎる場合は選択しなおしています。

データ数に対してセントロイドの数が上回る場合は無限ループに陥る可能性がありますが、その制御は呼び出し元で行っています。
また、全てのデータ同士の距離が小さすぎる場合にも無限ループに陥りますが、ここでは発生しないものとして無視しています。

class KMeans {
    add(centroids, datas) {
        centroids = centroids.map(c => new DataVector(c));
        while (true) {
            const p = new DataVector(datas[randint(0, datas.length - 1)]);
            if (Math.min.apply(null, centroids.map(c => p.distance(c))) > 1.0e-8) {
                return p.value;
            }
        }
    }

なお、DataVectorは配列に対してベクトル演算を実行できるようにするクラスで、distance関数によって他のDataVectorインスタンスとのユークリッド距離を計算しています。

セントロイドの移動先は、自身が最も近いデータ群の重心となります。
一度データの分類結果を取得(predict)して、各セントロイドに属するデータの重心を計算しています。

    _mean(d) {
        const n = d.length
        const t = d[0].length
        const m = Array(t).fill(0);
        for (let i = 0; i < n; i++) {
            for (let k = 0; k < t; k++) {
                m[k] += d[i][k]
            }
        }
        return m.map(v => v / n);
    }

    move(model, centroids, datas) {
        let pred = model.predict(datas);
        return centroids.map((c, k) => {
            let catpoints = datas.filter((v, i) => pred[i] === k);
            return this._mean(catpoints)
        });
    }
}

k-means++

k-means++法は、k-means法と比較して新規のセントロイドの選択方法が変わるだけなので、KMeansクラスを継承したクラスを作ります。

新規のセントロイドの選択は、各データと最近傍セントロイドとの距離によって確率的に決めます。
それら距離の累積値を累積分布関数と見立て、それに対する逆関数法を用いて選択しています。なお実装上、$[0,1]$に正規化していません。

export class KMeanspp extends KMeans {
    add(centroids, datas) {
        if (centroids.length == 0) {
            return datas[randint(0, datas.length - 1)]
        }
        centroids = centroids.map(c => new DataVector(c));
        const d = datas.map(d => new DataVector(d)).map(p => Math.min.apply(null, centroids.map(c => p.distance(c))) ** 2);
        const s = d.reduce((acc, v) => acc + v, 0);
        let r = Math.random() * s;
        for (var i = 0; i < d.length; i++) {
            if (r < d[i]) {
                return datas[i];
            }
            r -= d[i];
        }
    }
}

k-medois

k-medois法は結局はセントロイドの移動先が変わるだけなので、こちらもKMeansクラスを継承したクラスを作ります。
セントロイドの移動先は、そのクラスに属するデータの中で、他のデータとの距離の総和が最も小さいデータになります。

class KMedoids extends KMeans {
    move(model, centroids, datas) {
        let pred = model.predict(datas);
        return centroids.map((c, k) => {
            let catpoints = datas.filter((v, i) => pred[i] === k).map(v => new DataVector(v));
            if (catpoints.length > 0) {
                let i = argmin(catpoints, cp => {
                    return catpoints.map(cq => cq.distance(cp)).reduce((acc, d) => acc + d, 0);
                });
                return catpoints[i].value;
            } else {
                return c;
            }
        });
    }
}

なおargminは最小の値の位置を返す関数です。第一引数で配列を、第二引数で比較する値を返す関数を渡します。

Neural Gas

数学的にはSelf-organization mapやNeural Networkから解釈するようですが、ここでの実装はk-means法をベースに作成しました。
Wikipediaでの$w$がセントロイドに該当します。

新規のセントロイドの追加はKMeansと同様としていますが、セントロイドの移動先計算が大きく異なります。
やはりこれもKMeansクラスを継承したクラスとします。

class NeuralGas extends KMeans {
    // https://en.wikipedia.org/wiki/Neural_gas
    constructor() {
        this._l = 1;
        this._eps = 1;
        this._epoch = 0;
        this._sample_rate = 0.8;
    }

    move(model, centroids, datas) {
        const x = datas.filter(v => Math.random() < this._sample_rate).map(v => new DataVector(v));
        this._epoch++;
        const cvec = centroids.map(c => new DataVector(c));
        const distances = x.map(v => {
            let ds = cvec.map((c, i) => [i, v.distance(c)])
            ds.sort((a, b) => a[1] - b[1]);
            ds = ds.map((d, k) => [d[0], d[1], k])
            ds.sort((a, b) => a[0] - b[0]);
            return ds;
        })
        return cvec.map((c, n) => {
            const updates = distances.map((v, i) => x[i].sub(c).mult(this._eps * Math.exp(-v[n][2] / this._l)))
            const update = updates.slice(1).reduce((acc, v) => acc.add(v), updates[0]).div(updates.length);
            return c.add(update).value;
        });
    }
}
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1