LoginSignup
1
0

More than 3 years have passed since last update.

k-means法を理解する

Last updated at Posted at 2021-01-25

本記事は教師なし学習において代表的な「k-means」に関するもの。

初学者
・PythonとScikit-learnを用いてk-meansを実装したい。
中級者
・k-meansの発展、k-means++までしっかりと理解したい。
・エルボー法を用いてクラスタ数kの最適化を行いたい。

初学者

アルゴリズムの深い理解は必要なく、とりあえずk-mean法の実装だけを目的とした方向けの内容。

k-means法のアルゴリズム(視覚的理解)

k-means法では、下記手順に従ってデータのクラスタリングを行う。
初めにクラスタ数Nを決定する。(今回の場合 N=2 とする)
 ①クラスタ数Nに従い、データにクラスタをランダムに付与する。
 ②各クラスタの重心を求める。
 ③各データにおいて最も距離が近い重心のクラスタに変更する。
 ④各クラスタの重心が収束するまで②、③の処理を繰り返す。
k-means1.png

k-means法の実装(ライブラリScikit-learnを使用)

scikit-learn とはPythonのオープンソース機械学習ライブラリであり、比較的簡単にk-means法の実装が可能。

コードは下記の通り。

JyupyterNotebook
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.cluster import KMeans

# Irisデータセットを読み込む
iris = load_iris()
iris_hstack = np.hstack((iris.data, iris.target.reshape(-1,1)))
x_data = iris_hstack[iris_hstack[:,4] != 2, :][:, [0, 2]] # 今回は'setosa', 'versicolor'のデータのみ使う

# クラスタリング前の散布図を描画
fig, ax = plt.subplots(figsize=(10, 4), dpi=100, ncols=2)
ax[0].scatter(x_data[:,0],x_data[:,1])
ax[0].set_title('クラスタリング前')
ax[0].set_xlabel('Sepal Width')
ax[0].set_ylabel('Petal Width')

# ハイパーパラメータのクラスタ数を3に設定(再現性を持たせるため、random_state100固定)
km = KMeans(n_clusters=3, init='random', n_init=10, random_state=100)
# Kmeansを実行
y_km = km.fit_predict(x_data)

# クラスタ1を作成
# 散布図の作成
ax[1].scatter(x_data[y_km == 0, 0], x_data[y_km == 0,1], s=50, edgecolor='black', marker='s', label='cluster 1')
# クラスタの中心を作成
ax[1].plot(np.mean(x_data[y_km == 0, 0]), np.mean(x_data[y_km == 0, 1]), marker='x', markersize=15, color='red')

# クラスタ2を作成
# 散布図の作成
ax[1].scatter(x_data[y_km == 1, 0], x_data[y_km == 1,1], s=50, edgecolor='black', marker='o', label='cluster 2')
# クラスタの中心を作成
ax[1].plot(np.mean(x_data[y_km == 1, 0]), np.mean(x_data[y_km == 1, 1]), marker='x', markersize=15, color='red')

# クラスタ3を作成
# 散布図の作成
ax[1].scatter(x_data[y_km == 2, 0], x_data[y_km == 2, 1], s=50, edgecolor='black', marker='v', label='cluster 2')
# クラスタの中心を作成
ax[1].plot(np.mean(x_data[y_km == 2, 0]), np.mean(x_data[y_km == 2, 1]), marker='x', markersize=15, color='red')

ax[1].set_title('クラスタリング後(k-menas法)')
ax[1].set_xlabel('Sepal Width')
ax[1].set_ylabel('Petal Width')
plt.show()

# グラフを保存
fig.savefig("k-means_graph.png")

k-means_graph.png

長ったらしくコードが書かれているが、実際にk-means法を実装しているのは19~21行目のみ(それ以外のコードは前処理やグラフ表示に関するもの)

2021/01/24 今後追記予定

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0