より高速なKModes, KPrototypesを実装
こちらのコードを参考に、より高速化したKModes、KPrototypesを実装しました。距離尺度については、カテゴリ特徴量はハミング距離、数値特徴量はユークリッド距離のみを実装しています。
高速化
処理時間がかかる部分をCで置き換えました。KMeansと同じく処理の重い箇所は、
- 入力データとクラスター中心との間の距離計算(距離行列)
- クラスター中心の更新
の二つです。カテゴリ特徴量の場合、前者はSIMDによる置き換えも容易なので、非常に簡易的なものですがそちらも実装しました。数値特徴量では、scikit-learnのKMeansでも用いられているようにGEMMを活用することで、かなりの高速化が実現できます。おそらくカテゴリ特徴量もGEMM実装のいろいろなテクニックを適用できればもっと速くなるかと思うのですが、難易度が高過ぎたため、着手していません。またSIMD実装は今回実験したデータセットでは意味がありませんでした。
後者のクラスタ更新は、メモリへの連続アクセス方法が思い浮かばなかったため、SIMD対応はしていません。ただ、カテゴリ数が多い場合には高速になるように(多分)キャッシュを意識した実装になっているようにしてみました。
上記のC言語で実装した関数の一部は、入力データの型と特徴量数に応じて自動的に内部で生成しています。特徴量数が16/32未満ではSIMDを用いる意味がないため、単純な実装を呼び出します。それぞれ以下のファイル内の処理を確認してみてください。
- 入力データとクラスター中心との間の距離計算(距離行列)
-
dist_mat_256_hamming.c
(256は列数。入力データの形状により異なる)- 入力データ
X
(NxK)とクラスタ中心C
(MxK)を元に、距離行列D
(NxM)を計算する - 同じく
dist_vec_256_hamming.c
もあるがこちらは、入力データX
(NxK)と1個のクラスタ中心c
(K)を元に、距離ベクトルd
(N)を計算する
- 入力データ
-
- クラスター中心の更新
common_funcs.c
使い方
1. 入力データはnp.ndarray
のみ
2. KModesの場合はuint8
, uint16
のみ。Kprototypesの場合はfloat
, double
も受け付けるが、内部でカテゴリ特徴量部分はuint8
or uint16
に変換される。
KModes
KModesが受け付ける入力データの型はuint8
、 uint16
のみです。uint16
で6万以上のカテゴリを表現でき、それ以上のカテゴリ数があったとしても、解釈性から言って意味のないものと考えたためです(100以上の時点で十分だと思いますが)。入力データは事前にsklearn.preprocessing.LabelEncoder
などによる整数値への置換が必要です。
import numpy as np
from FasterKModes import FasterKModes
N = 1100 # 行数
K = 31 # 列数
C = 8 # クラスタ数
X = np.random.randint(0, 256, (N, K)).astype(np.uint8)
X_train = X[:1000,:]
X_test = X[1000:,:]
fKModes = FasterKModes(n_clusters=C, init="random", n_init=10)
fKModes.fit(X_train)
labels = fKModes.predict(X_test)
KPrototypes
KPrototypesはuint8
、uin16
に加え、float
、double
を受け付けます。KPrototypesは入力X
とカテゴリ特徴量がカラムのどの位置にあるかを表すlist
を引数とします。
float
かdouble
を入力データの型とした際に、カテゴリ特徴量列は最大値に応じて自動的にuint8
かuin16
に変換される(小数点以下が無視される)ため、注意が必要です。また、最大値が'uint16'上限を超えている場合にはエラーが出ます。
import numpy as np
import FasterKPrototypes as fkp
N = 1100 # 行数
K = 31 # 列数
C = 8 # クラスタ数
X = np.random.randint(0, 256, (N, K)).astype(np.uint8)
X_train = X[:1000,:]
X_test = X[1000:,:]
categorical_features = [1, 3, 5, 7, 9]
fKProto = fkp.FasterKPrototypes(n_clusters=C, init="random", n_init=10)
fKProto.fit(X_train, categorical_feature_indices=categorical_features)
labels = fKProto.predict(X_test)
精度・速度比較
オリジナルの実装と精度と速度を比較しました。データは、sklearn.datasets.make_classification
で生成したデータX
を、LightGBMで学習し、そのLeafIndexを出力させた出力データX'
を用いました。比較の際に変化させた点は、
- 行数:生成データ
X
の行数 - 列数:LightGBMの
n_estimators
。生成データX
の列数ではない - 最大値:
X'
の取り得る最大値。LightGBMのmax_depth
からpow(2, max_depth)
で計算される=LeafNodeの取り得る最大の数に相当する - クラスタ数:KModesのクラスタ数
精度
精度(cost)は、各データ点とそれらに最も近いクラスター中心の距離までの総和です。完全に同一の値にすることは困難なため、同じ入力データ・同じクラスタ数に対して、同程度の精度が出ればよいものとしています。
下記の表がその結果です。初期値はランダムを用いています。おおむね同程度の結果が出ているかと思います。
KModesのコスト比較
クラスタ数 | (1000, 8) | (10000, 8) | (100000, 8) | |||
---|---|---|---|---|---|---|
元実装 | 再実装 | 元実装 | 再実装 | 元実装 | 再実装 | |
2 | 5,466 | 5,466 | 56,118 | 55,937 | 554,695 | 554,695 |
4 | 4,353 | 4,365 | 43,956 | 44,325 | 448,733 | 445,752 |
8 | 3,105 | 3,086 | 32,222 | 33,052 | 323,439 | 323,521 |
16 | 2,264 | 2,332 | 23,035 | 24,451 | 241,809 | 243,484 |
32 | 1,632 | 1,724 | 17,443 | 18,388 | 168,879 | 182,752 |
64 | 1,170 | 1,147 | 12,065 | 12,261 | 115,768 | 119,523 |
128 | 712 | 680 | 7,835 | 7,723 | 74,772 | 78,843 |
256 | 402 | 328 | 4,992 | 4,523 | 48,861 | 48,304 |
KPrototypesのコスト比較
クラスタ数16以上はKprototypesでは実行できなかったため、スキップしています。
https://github.com/nicodv/kmodes#:~:text=Q%3A%20I%27m%20getting%20the%20following%20error%3A%20%22ValueError%3A%20Clustering%20algorithm%20could%20not%20initialize.%20Consider%20assigning%20the%20initial%20clusters%20manually.%22
クラスタ数 | (1000, 8) | (10000, 8) | (100000, 8) | |||
---|---|---|---|---|---|---|
元実装 | 再実装 | 元実装 | 再実装 | 元実装 | 再実装 | |
2 | 46,108 | 46,108 | 460,290 | 460,290 | 6,021,514 | 6,021,515 |
4 | 22,208 | 21,901 | 235,031 | 233,540 | 1,703,346 | 1,697,376 |
8 | 14,645 | 14,654 | 88,984 | 84,085 | 1,277,582 | 1,244,299 |
速度
行数、列数、最大値(カテゴリのユニーク数)、クラスタ数をいろいろ変化させてオリジナル実装との速度を比較します。(もっと行数を増やした場合も比較したかったのですが、オリジナル実装の4万行以上の実行は時間がかかりすぎたので、スキップしています。)
KModes
その他
KPrototype plus (kpplus)
実験を終わらせてから気づいたのリポジトリです。オリジナル実装にNumbaによる最適化を施しているようです。
距離尺度の拡張
現状ユークリッド距離、ハミング距離しか実装ができていないため、いくつか追加していきたいと考えています。
以上