本記事は教師なし学習において代表的な「k-means」に関するもの。
初学者
・PythonとScikit-learnを用いてk-meansを実装したい。
中級者
・k-meansの発展、k-means++までしっかりと理解したい。
・エルボー法を用いてクラスタ数kの最適化を行いたい。
初学者
アルゴリズムの深い理解は必要なく、とりあえずk-mean法の実装だけを目的とした方向けの内容。
k-means法のアルゴリズム(視覚的理解)
k-means法では、下記手順に従ってデータのクラスタリングを行う。
初めにクラスタ数Nを決定する。(今回の場合 N=2 とする)
①クラスタ数Nに従い、データにクラスタをランダムに付与する。
②各クラスタの重心を求める。
③各データにおいて最も距離が近い重心のクラスタに変更する。
④各クラスタの重心が収束するまで②、③の処理を繰り返す。
k-means法の実装(ライブラリScikit-learnを使用)
scikit-learn とはPythonのオープンソース機械学習ライブラリであり、比較的簡単にk-means法の実装が可能。
コードは下記の通り。
JyupyterNotebook
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
# Irisデータセットを読み込む
iris = load_iris()
iris_hstack = np.hstack((iris.data, iris.target.reshape(-1,1)))
x_data = iris_hstack[iris_hstack[:,4] != 2, :][:, [0, 2]] # 今回は'setosa', 'versicolor'のデータのみ使う
# クラスタリング前の散布図を描画
fig, ax = plt.subplots(figsize=(10, 4), dpi=100, ncols=2)
ax[0].scatter(x_data[:,0],x_data[:,1])
ax[0].set_title('クラスタリング前')
ax[0].set_xlabel('Sepal Width')
ax[0].set_ylabel('Petal Width')
# ハイパーパラメータのクラスタ数を3に設定(再現性を持たせるため、random_state100固定)
km = KMeans(n_clusters=3, init='random', n_init=10, random_state=100)
# Kmeansを実行
y_km = km.fit_predict(x_data)
# クラスタ1を作成
# 散布図の作成
ax[1].scatter(x_data[y_km == 0, 0], x_data[y_km == 0,1], s=50, edgecolor='black', marker='s', label='cluster 1')
# クラスタの中心を作成
ax[1].plot(np.mean(x_data[y_km == 0, 0]), np.mean(x_data[y_km == 0, 1]), marker='x', markersize=15, color='red')
# クラスタ2を作成
# 散布図の作成
ax[1].scatter(x_data[y_km == 1, 0], x_data[y_km == 1,1], s=50, edgecolor='black', marker='o', label='cluster 2')
# クラスタの中心を作成
ax[1].plot(np.mean(x_data[y_km == 1, 0]), np.mean(x_data[y_km == 1, 1]), marker='x', markersize=15, color='red')
# クラスタ3を作成
# 散布図の作成
ax[1].scatter(x_data[y_km == 2, 0], x_data[y_km == 2, 1], s=50, edgecolor='black', marker='v', label='cluster 2')
# クラスタの中心を作成
ax[1].plot(np.mean(x_data[y_km == 2, 0]), np.mean(x_data[y_km == 2, 1]), marker='x', markersize=15, color='red')
ax[1].set_title('クラスタリング後(k-menas法)')
ax[1].set_xlabel('Sepal Width')
ax[1].set_ylabel('Petal Width')
plt.show()
# グラフを保存
fig.savefig("k-means_graph.png")
長ったらしくコードが書かれているが、実際にk-means法を実装しているのは19~21行目のみ(それ以外のコードは前処理やグラフ表示に関するもの)
2021/01/24 今後追記予定