Posted at

Go初心者がGoでk-meansを実装してみた

More than 1 year has passed since last update.

Goの勉強として、クラスタリングの手法の1つであるk-meansを実装してみました。今回は簡単のため、2次元平面上に点の集まりをクラスタリングしました。

Go初心者なので変な書き方があるかと思いますが、そのときはコメントでご指摘お願いします。


データ作り

こんな感じのデータを作りました。3つのクラスターに分かれていることが目視で確認できます。

グラフはGnuplotで作りました。Go言語のプログラムでデータをファイルで出力して、それをGnuplotで読み込ませました。データは乱数を使って作りました。データづくりのコードは次の通りです。


make_data.go

package main

import (
"fmt"
"os"
"math/rand"
)

type XY struct{ X, Y float64}

func main() {
// 乱数初期化
rand.Seed(int64(0))

// 出力ファイルオープン
file, err := os.Create("points.dat")
if err != nil {
panic(err)
}
defer file.Close()

// データ作って、ファイルに出力
clusterSize := 100
writeData(file, createData(XY{0, 0}, clusterSize))
writeData(file, createData(XY{-4, 4}, clusterSize))
writeData(file, createData(XY{4, 4}, clusterSize))
}

// ファイルに出力(Gnuplotのデータ形式)
func writeData(file *os.File, data []XY) {
for _, xy := range data {
file.WriteString(fmt.Sprintf("%f %f\n", xy.X, xy.Y))
}
}

func createData(center XY, num int) []XY {
data := make([]XY, num)
for i := range data {
data[i].X = rand.NormFloat64() + center.X
data[i].Y = rand.NormFloat64() + center.Y
}
return data
}



k-means

上で作ったデータを使ってk-meansでクラスタリングします。実装したコードは次の通りです。計算効率はあまり考えずに実装しました。


k_means.go

package main

import (
"fmt"
"os"
"bufio"
"strings"
"strconv"
"math"
"math/rand"
)

type XY struct { X, Y float64 }

type Cluster []XY

func main() {
rand.Seed(int64(114514))
numCluster := 3

clusters := readData("points.dat", numCluster)

for n := 0; n < 10; n++ {
// 重心計算
centers := make([]XY, 3)
for i := 0; i < numCluster; i++ {
centers[i] = centerOfCluster(clusters[i])
}
// クラスターの更新
clusters = updateClusters(clusters, centers, numCluster)
}

writeData(clusters, "result.dat", numCluster)
}

// ファイルからデータ読み込み、ついでに初期クラスタに分ける
func readData(fileName string, numCluster int) []Cluster {
file, err := os.Open(fileName)
if err != nil {
panic(err)
}
defer file.Close()

clusters := make([]Cluster, numCluster)

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
xy := strings.SplitN(line, " ", 2)

x, _ := strconv.ParseFloat(xy[0], 64)
y, _ := strconv.ParseFloat(xy[1], 64)
p := XY{x, y}

group := rand.Intn(numCluster)
clusters[group] = append(clusters[group], p)
}

return clusters
}

// Gnuplotのデータ形式で書き込み
// 1クラスターのデータを1ブロックに
func writeData(clusters []Cluster, fileName string, numCluster int) {
file, err := os.Create(fileName)
if err != nil {
panic(err)
}
defer file.Close()

for i := 0; i < numCluster; i++ {
for _, p := range clusters[i] {
file.WriteString(fmt.Sprintf("%f %f\n", p.X, p.Y))
}
file.WriteString("\n")
}
}

// クラスターの重心を計算
func centerOfCluster(cluster Cluster) XY {
clusterSize := len(cluster)
var sumX float64 = 0
var sumY float64 = 0

for _, p := range cluster {
sumX += p.X
sumY += p.Y
}

cX := sumX / float64(clusterSize)
cY := sumY / float64(clusterSize)

return XY{cX, cY}
}

// 2点間の距離
func distance(p XY, q XY) float64 {
d2 := math.Pow(p.X - q.X, 2.0) + math.Pow(p.Y - q.Y, 2.0)
return math.Sqrt(d2)
}

// クラスターの更新
func updateClusters(clusters []Cluster, centers []XY, numCluster int) []Cluster {
newClusters := make([]Cluster, numCluster)

for i := 0; i < numCluster; i++ {
for _, p := range clusters[i] {
group := 0
minDistance := distance(p, centers[0])

// 距離が最も近いクラスターを探す
for j := 1; j < numCluster; j++ {
d := distance(p, centers[j])
if d < minDistance {
group = j
minDistance = d
}
}

newClusters[group] = append(newClusters[group], p)
}
}

return newClusters
}


本当は収束判定をしてk-meansの計算を打ち切るかどうかを決めるべきですが、今回は簡略化のため、回数を決め打ちしてk-meansの計算をしています。

実装してて思ったのは、typeで気軽に型に別名を付けられるのは便利ということですね。

さて、上のコードを使ってクラスタリングした結果はこのようになりました。

うまくクラスタリングできていますね。めでたし、めでたし。