はじめに
「色々な機械学習処理をブラウザ上で試せるサイトを作った」中で実装したモデルの解説の十回目です。
今回はAffinity Propagationの実装について解説します。
デモはこちらから。(TaskをClusteringにして、ModelのAffinity Propagationを選択)
実際のコードはaffinity_propagation.jsにあります。
なお、可視化部分については一切触れません。
概説
更新式について、いくつかパターンが存在するようです。
-
2006年提案手法
別の方の投稿に紹介された更新式です。
この元となった論文は「Frey et al. "Mixture Modeling by Affinity Propagation"(2006)」になるかと思います。
最初はこちらを実装した(__fit
)のですが、上手く動きませんでした。 -
2007年提案手法
調べると、同じ著者による別の手法が存在しました。
論文は「Frey et al. "Clustering by Passing Messages Between Data Points"(2007)」になります。「Fujiwara et al. "Fast Algorithm for Affinity Propagation"(2011)」に同様のアルゴリズムの詳細が載っていましたので、この数式を参考に実装しました(
fit
)。なお、「松下ら "メッセージ集約に基づくAffinity Propagationの高速化"(2019)」にもほぼ同様のアルゴリズムが載っていました。英語が苦手な方はこちらも参考にしてください。
アルゴリズム
「Fujiwara et al. "Fast Algorithm for Affinity Propagation"(2011)」からの引用となります。
計算アルゴリズム部分のみですので、理論が知りたい方は論文を参照してください。
初期化
データ$x_i (i = 1, 2, ..., N)$に対して、零行列$\bf\it R, A \in \mathbb{R}^{N \times N}$を用意します。
また行列$\bf\it S \in \mathbb{R}^{N \times N}$は、適当な類似度を計算する関数$\it{D} \left(x_i, x_j \right)$を用いて
s_{i, j} = \it{D} \left(x_i, x_j \right)
と設定します。
実装では、$\it D$はユークリッド距離の二乗に$-1$を掛けた値を使用しています。ただし、対角成分に関しては非対角成分の最小値としています。
また$0 < \lambda < 1$となる変数を用意します。
更新
\begin{eqnarray}
\rho_{i, j} &=& \left\{ \begin{array}{ll}
s_{i, j} - \max_{k \not= j} \left( a_{i, k} + s_{i, k} \right) & (i \not= j) \\
s_{i, j} - \max_{k \not= j} s_{i, k} & (i = j)
\end{array} \right.
\\
\alpha_{i, j} &=& \left\{ \begin{array}{ll}
\min \left(0, r_{j, j} + \sum_{k \not= i, j} \max \left(0, r_{k, j} \right) \right) & (i \not= j) \\
\sum_{k \not= i} \max \left(0, r_{k, j} \right) & (i = j)
\end{array} \right.
\end{eqnarray}
を計算し、これより
\begin{eqnarray}
r_{i, j} &=& (1 - \lambda) \rho_{i, j} + \lambda r_{i, j}
\\
a_{i, j} &=& (1 - \lambda) \alpha_{i, j} + \lambda a_{i, j}
\end{eqnarray}
と更新します。
推定
データ$x_i$のexemplar(模範)となるデータは、次の通り求めます
\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\argmax_{j} \left( r_{i, j} + a_{i, j} \right)
コード
Affinity Propagation
class AffinityPropagation {
// https://qiita.com/daiki_yosky/items/98ce56e37623c369cc60
// https://tjo.hatenablog.com/entry/2014/07/31/190218
constructor() {
this._epoch = 0;
this._l = 0.8;
this._x = [];
this._y = null;
}
get centroidCategories() {
const y = this.predict();
return [...new Set(y)]
}
get centroids() {
return this.centroidCategories.map(i => this._x[i]);
}
get size() {
const y = this.predict();
return new Set(y).size
}
get epoch() {
return this._epoch
}
init(datas) {
this._x = datas;
const n = datas.length;
this._r = Array(n);
this._a = Array(n);
this._ar = Array(n);
this._s = Array(n);
this._as = Array(n);
for (let i = 0; i < n; i++) {
this._r[i] = Array(n).fill(0)
this._a[i] = Array(n).fill(0)
this._ar[i] = Array(n).fill(0)
this._s[i] = Array(n)
this._as[i] = Array(n)
}
this._y = null;
this._epoch = 0;
let min = Infinity
for (let i = 0; i < n; i++) {
for (let j = 0; j < i; j++) {
if (i === j) continue
const d = -this._x[i].reduce((s, v, k) => s + (v - this._x[j][k]) ** 2, 0)
this._s[i][j] = this._s[j][i] = d;
this._as[i][j] = this._as[j][i] = d;
min = Math.min(min, d)
}
}
for (let i = 0; i < n; i++) {
this._s[i][i] = this._as[i][i] = min;
}
}
fit() {
// Frey. et al. "Clustering by Passing Messages Between Data Points" (2007)
// "Fast Algorithm for Affinity Propagation"
const x = this._x;
const n = x.length;
const l = this._l
for (let i = 0; i < n; i++) {
for (let k = 0; k < n; k++) {
let m = -Infinity
const ss = (i === k) ? this._s[i] : this._as[i];
for (let kd = 0; kd < n; kd++) {
if (k === kd) continue
m = Math.max(m, ss[kd])
}
this._r[i][k] = l * this._r[i][k] + (1 - l) * (this._s[i][k] - m)
}
}
for (let i = 0; i < n; i++) {
for (let k = 0; k < n; k++) {
let s = (i === k) ? 0 : this._r[k][k];
for (let id = 0; id < n; id++) {
if (id !== i && id !== k) {
s += Math.max(0, this._r[id][k])
}
}
if (i !== k) s = Math.min(0, s)
const aik = l * this._a[i][k] + (1 - l) * s;
this._a[i][k] = aik;
this._ar[i][k] = aik + this._r[i][k];
this._as[i][k] = aik + this._s[i][k];
}
}
this._y = null;
this._epoch++;
}
__fit() {
// Frey. et al. "Mixture Modeling by Affinity Propagation" (2006)
const x = this._x;
const n = x.length;
for (let i = 0; i < n; i++) {
for (let j = 0; j < n; j++) {
let s = 0;
for (let k = 0; k < n; k++) {
if (k === j) continue
s += this._a[i][k] * this._s[i][k];
}
this._r[i][j] = this._s[i][j] / s
}
}
for (let i = 0; i < n; i++) {
let p = 1;
for (let k = 0; k < n; k++) {
if (k === i) continue
p *= (1 + this._r[k][i]);
}
this._a[i][i] = p - 1
this._ar[i][i] = this._a[i][i] + this._r[i][i]
for (let j = 0; j < n; j++) {
if (i === j) continue;
p = 1 / this._r[i][i] - 1;
for (let k = 0; k < n; k++) {
if (k === i || k === j) continue
p *= 1 / (1 + this._r[k][i])
}
this._a[i][j] = 1 / (1 + p)
this._ar[i][j] = this._a[i][j] + this._r[i][j]
}
}
this._y = null;
this._epoch++;
}
predict() {
if (!this._y) {
this._y = [];
const n = this._x.length;
for (let i = 0; i < n; i++) {
let max_v = -Infinity;
let max_i = -1;
for (let j = 0; j < n; j++) {
const v = this._ar[i][j];
if (max_v < v) {
max_v = v;
max_i = j;
}
}
this._y.push(max_i);
}
}
return this._y;
}
}