1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

x-means法のJavaScriptによる実装

Last updated at Posted at 2020-10-18

はじめに

色々な機械学習処理をブラウザ上で試せるサイトを作った」中で実装したモデルの解説の二回目です。

今回はx-means法の実装について解説します。

デモはこちらから。(TaskをClusteringにして、ModelのX-Meansを選択)
実際のコードはxmeans.jsにあります。

なお、可視化部分については一切触れません。

概説

アルゴリズムなどについては既に良い記事がありますので、こちらも参考にしてください。

処理の流れ

fit関数により分割処理を行います。

  1. 初期分割状態の生成
    • 一回目の呼び出しの場合は、k-means法により初期クラスタ数に分割(_split_cluster)し、各クラスタ情報を取得します
    • 二回目以降の場合は、自分自身の状態から各クラスタ情報を取得します
  2. 存在する各クラスタに対してk-means法による分割(_split_cluster)を実施します
    • 分割後のBICが小さくなった場合は、その分割を採用します
    • 逆に大きくなった場合は、その分割は行いません
  3. 指定されただけ、あるいは分割できなくなるまで2.を繰り返します

各クラスタ情報は以下のデータで、BIC計算のために使用するなどします。

  • size : データ数
  • cols : 特徴量数
  • data : データ(size×colsなる配列の配列)
  • cov : 分散共分散行列(cols×colsなる行列)
  • centroid : セントロイド
  • llh : 対数尤度
  • bic : BIC

この情報を初回呼び出し時およびクラスタ分割時に取得(_create_clusters)しています。

また内部で使用するk-means法の処理は、先日投稿した記事で定義したKMeansModelで実行します。

行列演算

各種行列演算は手作り数学ライブラリmath.jsの中で定義してあるMatrixで実行します。

コード中で使用している処理はそれぞれ以下の通りです。

  • プロパティ
    • t : 転置行列を返す
    • value : 内部で持っている一次元配列を返す
    • rows : 行数を返す
    • cols : 列数を返す
  • 関数
    • cov : 分散共分散行列を返す
    • det : 行列式を返す
    • dot : 行列積を返す
    • mean : 平均値を返す
    • row : 行ベクトルを(1行の行列として)返す
    • sub : 減算する(inplace処理)

正規分布の累積分布関数

BICの計算において正規分布の累積分布関数が必要になります。

このサイトによると、標準正規分布$f(x)$についてその累積分布関数$F(x)$は以下の近似式が知られているそうです。

f(x) = \frac{1}{\sqrt{2 \pi}} e^{-\frac{x^2}{2}}
\\
のとき
\\
F(x) = \int_{-\infty}^x f(x) dx \simeq \frac{1}{1 + e^{-1.7 x}}

実装では、この近似式を用いて正規分布の累積分布関数を計算しています。

コード

class XMeans {
	// https://qiita.com/deaikei/items/8615362d320c76e2ce0b
	// https://www.jstage.jst.go.jp/article/jappstat1971/29/3/29_3_141/_pdf
	constructor() {
		this._centroids = [];
		this._init_k = 2;
	}

	get centroids() {
		return this._centroids;
	}

	get size() {
		return this._centroids.length;
	}

	_distance(a, b) {
		return Math.sqrt(a.reduce((acc, v, i) => acc + (v - b[i]) ** 2, 0));
	}

	clear() {
		this._centroids = [];
	}

	fit(datas, iterations = -1) {
		let clusters = null;
		if (this._centroids.length === 0) {
			clusters = this._split_cluster(datas, this._init_k);
			iterations--
		} else {
			clusters = this._create_clusters(this, datas);
		}
		const centers = [];

		while (clusters.length > 0 && (iterations < 0 || iterations-- > 0)) {
			const new_clusters = [];
			while (clusters.length > 0) {
				const c = clusters.shift();
				if (c.size <= 3) {
					centers.push(c.centroid)
					continue
				}
				const [c1, c2] = this._split_cluster(c.data);
				const beta = Math.sqrt(c1.centroid.reduce((s, v, i) => s + (v - c2.centroid[i]) ** 2, 0) / (c1.cov.det() + c2.cov.det()));
				// http://marui.hatenablog.com/entry/20110516/1305520406
				const norm_cdf = 1 / (1 + Math.exp(-1.7 * beta))
				const alpha = 0.5 / norm_cdf

				const df = c.cols * (c.cols + 3) / 2
				const bic = -2 * (c.size * Math.log(alpha) + c1.llh + c2.llh) + 2 * df * Math.log(c.size);

				if (bic < c.bic) {
					new_clusters.push(c1, c2)
				} else {
					centers.push(c.centroid)
				}
			}
			clusters = new_clusters;
		}
		if (clusters.length > 0) {
			centers.push(...clusters.map(c => c.centroid))
		}
		this._centroids = centers;
	}

	_split_cluster(datas, k = 2) {
		const kmeans = new KMeansModel();
		for (let i = 0; i < k; i++) {
			kmeans.add(datas);
		}
		while (kmeans.fit(datas) > 0);
		return this._create_clusters(kmeans, datas);
	}

	_create_clusters(model, datas) {
		const k = model.size;
		const p = model.predict(datas);
		const ds = [];
		for (let i = 0; i < k; ds[i++] = []);
		datas.forEach((d, i) => ds[p[i]].push(d));
		const clusters = [];
		for (let i = 0; i < k; i++) {
			const mat = Matrix.fromArray(ds[i]);
			const cov = mat.cov();
			const invcov = cov.inv()
			const mean = mat.mean(0);
			const cc = Math.log(1 / Math.sqrt((2 * Math.PI) ** mat.cols * cov.det()))
			let llh = cc * mat.rows;
			for (let j = 0; j < mat.rows; j++) {
				const r = mat.row(j);
				r.sub(mean);
				llh -= r.dot(invcov).dot(r.t).value[0] / 2
			}
			const df = mat.cols * (mat.cols + 3) / 2
			clusters[i] = {
				size: ds[i].length,
				cols: mat.cols,
				data: ds[i],
				cov: cov,
				centroid: model.centroids[i],
				llh: llh,
				bic: -2 * llh + df * Math.log(ds[i].length)
			}
		}
		return clusters;
	}

	predict(datas) {
		if (this._centroids.length == 0) {
			return;
		}
		return datas.map(value => {
			return argmin(this._centroids, v => this._distance(value, v));
		});
	}
}

さいごに

コードを読み直していると、関数名にキャメルケースとスネークケースが混ざっているのがかなり気になります。
自分で書いておいてあれですが。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?