はじめに
「色々な機械学習処理をブラウザ上で試せるサイトを作った」中で実装したモデルの解説の四回目です。
今回はMeanShiftの実装について解説します。ほぼ解説はしていませんが。
デモはこちらから。(TaskをClusteringにして、ModelのMean Shiftを選択)
実際のコードはmean_shift.jsにあります。
なお、可視化部分については一切触れません。
概説
実装はこのサイトを参考にしました。
この数式をだいたいそのままコードに書き起こしています。
なので具体的な流れはそちらを参照してもらい、実装で異なっている部分のみ説明します。
数式との相違点
二点、異なる部分があります。
-
更新式
上記サイトでは
\vec{\nabla} f_K \left(\vec{x}\right) = \frac{1}{h^{2}} f_G \left(\vec{x}\right) \vec{s}_G \left(\vec{x}\right)
となっていますが、実装では
\vec{\nabla} f_K \left(\vec{x}\right) = \vec{s}_G \left(\vec{x}\right)
としています。
fd
中のreturn
文のコメントアウトしてある方が数式と一致する更新式です。なお、なぜこのような変更を行ったのかは、記憶のかなたです。とりあえず動いているようには見えます。
-
体積の計算
二次元(円の面積)に固定されています。
ただし体積は$f_G \left(\vec{x}\right)$の計算にのみ使用するため、動作には影響はありません。正しい更新式を用いる場合は、超球の体積計算式を使用するように変更する必要があります。
超球の体積はWikipediaより、
V_n (r) = \frac{n^{\pi / 2}}{\Gamma \left(\frac{n}{2} + 1\right)} r^n
であり、またガンマ関数$\Gamma(x)$は同じくWikipediaより、自然数$n$について同じく
\begin{eqnarray} \Gamma(n + 1) &=& n! \\ \Gamma(n + \frac{1}{2}) &=& \frac{(2n - 1)!!}{2^n}\sqrt{\pi} \end{eqnarray}
が成り立ちます。
ただし、数式をそのまま実装すると桁数が大変なことになるので、全体を対数でくくったほうが良いです。超球の体積計算は、K近傍法に基づく密度推定手法に実装したものがあります。
この中のilogv
が超球の体積の逆数の対数を取ったものになります。つまり$\log{\frac{1}{V_n(r)}}$です。_logGamma(z) { let x = 0 if (Number.isInteger(z)) { for (let i = 2; i < z; i++) { x += Math.log(i) } } else { const n = z - 0.5 x = Math.log(Math.sqrt(Math.PI)) - Math.log(2) * n for (let i = 2 * n - 1; i > 0; i -= 2) { x += Math.log(i) } } return x } predict(data) { const ps = this._near_points(data) const r = ps[ps.length - 1].d const d = data.length const ilogv = this._logGamma(d / 2 + 1) - d / 2 * Math.log(Math.PI) - d * Math.log(r) return Math.exp(ilogv) * this.k / this._p.length }
推論
推論は、自身を中心とした超球の位置(this._centroids
)によって決定します。
同一のクラスタとする条件を、中心が厳密に一致する、とするとうまくいかない場合が多いので、一定の誤差を許容するようにします。
これをthis._threshold
より設定し、この距離よりも近い場合は同一のクラスタとするようにしました。
コード
class MeanShift {
// see http://seiya-kumada.blogspot.com/2013/05/mean-shift.html
// see http://takashiijiri.com/study/ImgProc/MeanShift.htm
constructor(h, threshold) {
this._x = null;
this._centroids = null;
this._h = h;
this._threshold = threshold;
this._categories = 0;
}
get categories() {
return this._categories;
}
get h() {
return this._h;
}
set h(value) {
this._h = value;
}
set threshold(value) {
this._threshold = value;
}
_distance(a, b) {
return Math.sqrt(a.reduce((s, v, i) => s + (v - b[i]) ** 2, 0))
}
init(data) {
this._x = data;
this._centroids = this._x.map(v => [].concat(v));
}
predict() {
this._categories = 0;
const p = []
for (let i = 0; i < this._centroids.length; i++) {
let category = i;
for (let k = 0; k < i; k++) {
if (this._distance(this._centroids[i], this._centroids[k]) < this._threshold) {
category = p[k];
break;
}
}
if (category === i) this._categories++;
p[i] = category;
}
return p;
}
fit() {
if (this._centroids.length === 0 || this._x.length === 0) {
return;
}
const d = this._centroids[0].length;
const Vd = Math.PI * (this._h ** 2);
const G = (x, x1) => x.reduce((acc, v, i) => acc + ((v - x1[i]) / this._h) ** 2, 0) <= 1 ? 1 : 0;
const mg = (gvalues) => {
let s = 0;
let v = Array(this._x[0].length).fill(0);
this._x.forEach((p, i) => {
if (gvalues[i]) {
s += gvalues[i];
v = v.map((a, j) => a + p[j] * gvalues[i])
}
});
return v.map((a, i) => a / s);
};
const sg = (x, gvalues) => mg(gvalues).map((v, i) => v - x[i]);
const fg = (gvalues) => {
return gvalues.reduce((acc, v) => acc + v, 0) / (gvalues.length * Vd);
}
const fd = (x) => {
let gvalues = this._x.map(p => G(x, p));
return sg(x, gvalues);
//return sg(x, gvalues).mult(2 / (this._h ** 2) * fg(gvalues));
};
let isChanged = false;
this._centroids = this._centroids.map((c, i) => {
let oldPoint = c;
const v = fd(c);
const newPoint = c.map((a, i) => a + v[i])
isChanged |= oldPoint.some((v, i) => v !== newPoint[i]);
return newPoint;
});
return isChanged;
}
}
さいごに
スコープの近い場所に同じ変数名を使う(最後の方のi
やv
)、という愚行を犯す。