5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pythonでk-means法をやってみる

Posted at

クラスタリング手法の中でもポピュラーなK-meansについて勉強する機会があったので、今回はPythonを用いてscikit-learnは用いずに実装してみました。が、当然の事ながら精度に関しては当然scikit-learn様が圧勝なので、お勉強の確認的な意味合いが強いです。それでも興味がある方は以下をどうぞ。
※numpyは使っていますが、そこはどうかご容赦ください。

環境

  • Python 3.6.5
  • Mac 10.13.4

K-means法

ご存知の方も多いと思いますが、K-means法は非階層型のクラスタリング手法です。対象データを任意のK個のクラスタに分類する最も単純で基本的なクラスタリング手法と言っても過言ではないでしょう。こちらのサイトが実際のクラスタリングの様子も視覚化されていて非常に分かりやすいです。

プログラム

こちらが今回作成したK-means法のプログラム本体です。

k_means.py
import numpy as np

class KMEANS:

    def __init__(self):
        self.reps=[]
        self.dists=[]
        self.clusters=[]
        self.keep_flag=True

    # c_data: クラスタリング対象データ
    # k: クラスタ数
    def Clustering(self, c_datas, k):
        # 代表点を初期化
        self.RepInit(c_datas, k)

        while self.keep_flag:
            # 代表点と点の距離を計算
            self.ClusterDist(c_datas)

            # 所属クラスタを更新
            self.ClusterUpdate()

            # 代表点を更新
            self.RepUpdate(c_datas)

    # クラスタ代表点の初期化 + 代表点と各点の距離と所属クラスタを格納するリストの初期化
    def RepInit(self, c_datas, k):
        for i in range(k):
            self.reps.append(list(np.random.rand(len(c_datas[0]))))
        self.dists = [[-1 for j in range(len(self.reps))] for i in range(len(c_datas))]
        self.clusters=[-1 for i in range(len(c_datas))]

    # 代表点と点の距離を計算
    def ClusterDist(self, c_datas):
        for (i, c_data) in enumerate(c_datas):
            for (j, rep) in enumerate(self.reps):
                # 各点の代表点との距離を計算
                self.dists[i][j] = self.Dist(c_data, rep)

    # 二点間の距離を計算
    def Dist(self, x1, x2):
        return np.linalg.norm(np.array(x1) - np.array(x2))

    # 所属クラスタ更新
    def ClusterUpdate(self):
        flag=False
        for (i, dist) in enumerate(self.dists):
            # クラスタ更新があった場合はwhileループのフラグをTrueに維持
            if self.clusters[i] != np.argmin(dist):
                flag=True
            # 距離のリストから最小値の引数を得る
            self.clusters[i] = np.argmin(dist)
        self.keep_flag=flag

     # クラスタの代表点を更新
    def RepUpdate(self, c_datas):
        for c_num in range(len(self.reps)):
            cluster_points=[]
            for i, (cluster, c_data) in enumerate(zip(self.clusters,c_datas)):
                if cluster == c_num:
                    # clauster_pointsにc_numクラスタの点を追加
                    cluster_points.append(c_datas[i])
            if len(cluster_points) is 0:
                cluster_points.append(self.reps[c_num])
            # 点の平均を求め代表点を更新
            self.reps[c_num] = list(np.array(cluster_points).mean(axis=0))

初めにRepInit()でクラスタの代表点を初期化しています。ひとまず今回はnumpyの関数を用いて0~1の間の範囲の乱数を生成することによってクラスタ代表点の初期値としています。self.repsの中身は[[クラスタ0(以下c0)の代表点],[クラスタ1(以下c1)の代表点],...,[クラスタn(以下cn)の代表点]]となっています。要素一つがそれぞれのクラスタに対応しています。
このときついでにクラスタの代表点と各データの距離を入れるリストや、所属クラスタを記録するリストも初期化しています。self.dists[[[データ1のc0との距離],[データ1のc1との距離],...,[データ1のcnとの距離]],[[データ2の]]]となっています。self.clustersに関しては、データごとの所属クラスタがリストです。
この次からK-meansのクラスタを求めるwhileループに入っていきます。
ClusterDistで各点と代表点の距離を計算を行います。今回距離の計算には、numpyの関数を用いています。まあ計算自体は簡単ですが、何も全部コードで書かなくてもね?とりあえずK-meansの流れを実際にコードとして書いてみるのが主題ですから。
代表点との距離を求めたので、ClusterUpdateで所属クラスタの更新を行います。これもnumpyでリスト内の最小値のインデックスを返すというドンピシャなものがあったので採用。また、この際にクラスタの更新が起こったかという点に関しても判定を行なっています。全てのデータにおいて前回と同じクラスタが選択された場合には、whileループの脱出フラグがたちます。
最後にRepUpdateにて代表点の更新を行なっています。各クラスタごとの所属データを一時的に保管し、平均を求めています。一応クラスタに所属するデータが存在しなかった場合に元の代表点を保持するようにしていますが、本当ならランダムに飛ばした方が良いんですかね。というか、所属しない状況が起こりうるのは最初の一回目くらいでしょうか。

実行

test.py
from k_means import KMEANS

test_data=[[1,1],
[5,5],
[80,7],
[6,6],
[90,5],
[77,10],
[88,90],
[77,80],
[66,100]]

clustering=KMEANS()

clustering.Clustering(test_data,3)

print(clustering.clusters)
$ python3 test.py 
[0, 0, 2, 0, 2, 2, 1, 1, 1]
$ python3 test.py 
[2, 0, 1, 0, 1, 1, 1, 1, 1]

データはとりあえず分かりやすくなるよう偏らせています。
K-meansは代表点の初期値によって結果が毎回異なるので、1回目と2回目で結果が違うのは問題ありません。が、2回目はちょっと酷いような……。そもそも初期値の取り方がデータの大きさに対して小さすぎるのが原因でしょうか。
scikit-learnの内部ではどんな感じで初期値を決めているのか少し気になります。加えてscikit-learnは初期値を変えてデフォルトで300回繰り返し良い数値を探しているらしいですね。
というわけで、みなさんPythonでK-means方をするときはscikit-learnを使いましょう!

参考サイト

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?