LoginSignup
2
5

More than 5 years have passed since last update.

【機械学習】Goで書いたナイーブなkNNにkmeansの考えを取り込んで高速化してみる

Posted at

概略

kNNのアルゴリズムに手を加えることで予測を高速化していく。kNNは訓練データの読み込み後に事前計算などをしておくことができないアルゴリズムで、予測段階でメイン計算が行われるため、読み込ませる訓練データの量によっては、予測で多くの時間を使うことになる。
ナイーブに実装したkNNアルゴリズムに手を加えることで、予測にかかる時間の変化と、予測精度の変化を実際に見ていく。

動機

最近Goで機械学習アルゴリズムの書き下しを進めている。自分でフルスクラッチで書いていくと、パッケージ使用時に比べてより柔軟に、任意の場所に任意の処理を加えることができるので、良い機会なので実験をしていきたい。
また、機械学習アルゴリズムの中身は決してアンタッチャブルなものではなく、目的に応じて弄っていくのもアリなのだというのを、示していきたい。

実験環境

使用するアルゴリズムはkNNで、コードは、【機械学習】GoでKNN(K近傍法)実装の記事に掲載したものを使用する。この記事では150行からなるirisデータを使用しているが、今回は、時間の計測を行うため、このirisデータを100個繋げたものを使用している。
データのサイズは以下の通り。

  • 行数:15000
  • 特徴量:4

このデータの半分を訓練データ、残りの半分をテストデータとし、7500件のデータに対する予測を早くしていくことになる。

予測高速化の方針

kNNの予測はなぜ遅い?

そもそも、kNNの予測がなぜ遅いのか?
それは、概略の部分でも書いたようにkNNが予測フェーズで初めてメインの計算を行うアルゴリズムで、事前計算ができないからだ。kNNは訓練データとして読み込ませておいた各データポイントと予測対象データとの距離を計算し、距離の近いk個の訓練データのデータポイントのラベルで多数決を行い、予測対象データの予測ラベルとする。
読み込ませる訓練データの量が多ければ距離計算の対象が増えるため、予測に時間がかかるようになる。

アプローチ

予測対象データとの距離を計算するポイントを減らす。ナイーブなkNNだと、予測対象点と、事前に読み込んでおいた訓練データのポイントの全ての距離を計算することになるので、そこを弄る。
手法はいくつか想像できるが、今回は、kmeansの考え方を利用する。具体的には、訓練データを読み込んでおいた段階で、kmeansの仕組みを用いて複数のデータの部分集合のセントロイド点を計算しておく。予測段階では、予測対象データとそれらのセントロイド点との距離を計算し、最も近いセントロイド点に予測対象データを所属したものと判断し、そのセントロイド点に帰属する訓練データのデータポイントとの距離のみを計算し、k個の最近傍の点で多数決を行う。

使用コード

かなりヤッツケコードだが【機械学習】GoでKNN(K近傍法)実装で紹介してるナイーブなkNNのコードを弄って速度改善を狙う。
具体的には、元来、データの格納しか行わなかったfit()メソッドの部分に実質的にkmeansを埋め込み、KNN{}コンストラクタに分類先のクラスタ、クラスタごとのセントロイドポイントを格納するようにする。predictの時にはまず、格納しておいたセントロイドポイントとの距離を計算し、最も距離の近いセントロイドポイントの示すクラスタに属するデータを距離計算対象として近傍点を探してきて多数決を行うようにした。
少し長くなるが、手を加えた後のコードは以下のものになる。具体的な変更点は上に挙げたように、KNN{}コンストラクタに持たせる変数にセントロイドポイント、データの属するクラスタを加えたことと、fit()メソッドにkmeans的なアルゴリズムを組み込んだこと、predict()メソッドで最初にセントロイド点との距離を計算するようにしたことだ。それ以外の点は上記リンク先のものと変わらない。

package main

import (
    "math"
    "sort"
    "os"
    "encoding/csv"
    "io"
    "strconv"
    "fmt"
    "reflect"
    "time"
    "math/rand"
)

func main() {
    //データ読み込み
    irisMatrix := [][]string{}
    iris, err := os.Open("big_iris.csv")
    if err != nil {
        panic(err)
    }
    defer iris.Close()

    reader := csv.NewReader(iris)
    reader.Comma = ','
    reader.LazyQuotes = true
    for {
        record, err := reader.Read()
        if err == io.EOF {
            break
        } else if err != nil {
            panic(err)
        }
        irisMatrix = append(irisMatrix, record)
    }

    //説明変数と被説明変数にデータを分割
    X := [][]float64{}
    Y := []string{}
    for _, data := range irisMatrix {

        //strスライスデータをfloatスライスデータに変換
        temp := []float64{}
        for _, i := range data[:4] {
            parsedValue, err := strconv.ParseFloat(i, 64)
            if err != nil {
                panic(err)
            }
            temp = append(temp, parsedValue)
        }
        //説明変数へ
        X = append(X, temp)

        //被説明変数
        Y = append(Y, data[4])

    }

    //データを訓練データとテストデータに分割
    var (
        trainX [][]float64
        trainY []string
        testX  [][]float64
        testY  []string
    )
    for i, _ := range X {
        if i%2 == 0 {
            trainX = append(trainX, X[i])
            trainY = append(trainY, Y[i])
        } else {
            testX = append(testX, X[i])
            testY = append(testY, Y[i])
        }
    }

    //学習
    knn := KNN{}
    knn.k = 8
    knn.fit(trainX, trainY)
    predicted := knn.predict(testX)

    //正答率確認
    correct := 0
    for i, _ := range predicted {
        if predicted[i] == testY[i] {
            correct += 1
        }
    }
    fmt.Println(correct)
    fmt.Println(len(predicted))
    fmt.Println(float64(correct) / float64(len(predicted)))

}
func Transpose(source [][]float64) [][]float64 {
    var dest [][]float64
    for i := 0; i < len(source[0]); i++ {
        var temp []float64
        for j := 0; j < len(source); j++ {
            temp = append(temp, 0.0)
        }
        dest = append(dest, temp)
    }

    for i := 0; i < len(source); i++ {
        for j := 0; j < len(source[0]); j++ {
            dest[j][i] = source[i][j]
        }
    }
    return dest
}

//スライスに格納されている値が最も小さくなるときにそのインデックスを返す
func ArgMin(target []float64) int {
    var (
        index int
        base  float64
    )
    for i, d := range target {
        if i == 0 {
            index = i
            base = d
        } else {
            if d < base {
                index = i
                base = d
            }
        }

    }
    return index
}

//二つのスライス間のユークリッド距離を計算
func Dist(source, dest []float64) float64 {
    val := 0.0
    for i, _ := range source {
        val += math.Pow(source[i]-dest[i], 2)
    }
    return math.Sqrt(val)
}

//argument sort作成
type Slice struct {
    sort.Interface
    idx []int
}

func (s Slice) Swap(i, j int) {
    s.Interface.Swap(i, j)
    s.idx[i], s.idx[j] = s.idx[j], s.idx[i]
}

func NewSlice(n sort.Interface) *Slice {
    s := &Slice{Interface: n, idx: make([]int, n.Len())}
    for i := range s.idx {
        s.idx[i] = i
    }
    return s
}

func NewFloat64Slice(n []float64) *Slice { return NewSlice(sort.Float64Slice(n)) }

//mapのソート
type Entry struct {
    name  string
    value int
}
type List []Entry

func (l List) Len() int {
    return len(l)
}

func (l List) Swap(i, j int) {
    l[i], l[j] = l[j], l[i]
}

func (l List) Less(i, j int) bool {
    if l[i].value == l[j].value {
        return l[i].name < l[j].name
    } else {
        return l[i].value > l[j].value
    }
}

//スライス中のアイテムの出現頻度をカウント
func Counter(target []string) map[string]int {
    counter := map[string]int{}
    for _, elem := range target {
        counter[elem] += 1
    }
    return counter
}

type KNN struct {
    k            int
    data         [][]float64
    labels       []string
    clusterLabel []int
    centroid     [][]float64
}

func (knn *KNN) fit(X [][]float64, Y []string) {
    //データを読み込む
    knn.data = X
    knn.labels = Y
    //データをランダムにクラスタに割り振り、初期値とする
    rand.Seed(time.Now().UnixNano())
    clusterSize := len(knn.labels) / 10
    for i := 0; i < len(knn.labels); i++ {
        knn.clusterLabel = append(knn.clusterLabel, rand.Intn(clusterSize))
    }
    //クラスタあたりのセントロイド点を計算
    for i := 0; i < clusterSize; i++ {
        var clusterGroup [][]float64
        for j, cluster := range knn.clusterLabel {
            if cluster == i {
                clusterGroup = append(clusterGroup, knn.data[j])
            }
        }
        transposedClusterGroup := Transpose(clusterGroup)
        clusterCentroid := []float64{}
        for _, values := range transposedClusterGroup {
            //平均を求める
            mean := 0.0
            for _, value := range values {
                mean += value
            }
            clusterCentroid = append(clusterCentroid, mean/float64(len(values)))

        }
        knn.centroid = append(knn.centroid, clusterCentroid)
    }

    for {
        //代表ベクトルの更新
        //ラベルに属するデータ点の平均値を代表ベクトルの更新値とする
        //インデックスiはラベルを表す
        var tempRepresentatives [][]float64
        for i, _ := range knn.centroid {
            var grouped [][]float64
            for j, d := range knn.data {
                if knn.clusterLabel[j] == i {
                    grouped = append(grouped, d)
                }
            }
            if len(grouped) != 0 {

                transposedGroup := Transpose(grouped)
                updated := []float64{}
                for _, vectors := range transposedGroup {

                    value := 0.0
                    for _, v := range vectors {
                        value += v
                    }
                    //特徴量ごとの平均を格納
                    updated = append(updated, value/float64(len(vectors)))
                }
                tempRepresentatives = append(tempRepresentatives, updated)
            }
        }
        knn.centroid = tempRepresentatives

        //ラベル更新
        tempLabel := []int{}
        for _, d := range knn.data {
            var distance []float64
            for _, r := range knn.centroid {
                distance = append(distance, Dist(d, r))
            }
            tempLabel = append(tempLabel, ArgMin(distance))
        }
        if reflect.DeepEqual(knn.clusterLabel, tempLabel) {
            break
        } else {
            knn.clusterLabel = tempLabel
        }
    }
}

func (knn *KNN) predict(X [][]float64) []string {

    predictedLabel := []string{}
    for _, source := range X {
        var (
            centDistList []float64
            distList     []float64
            nearLabels   []string
        )
        //最も近いセントロイド点を探す
        //予測対象データとセントロイド点との距離を計算
        for _, dest := range knn.centroid {
            centDistList = append(centDistList, Dist(source, dest))
        }
        //距離の最も近いセントロイド点のラベルを獲得
        centS := NewFloat64Slice(centDistList)
        sort.Sort(centS)
        centroidTargetLabel := centS.idx[0]

        //予測対象データと教師データとの距離を計算
        var tempLabel []string
        for i, dest := range knn.data {
            if knn.clusterLabel[i] == centroidTargetLabel {
                distList = append(distList, Dist(source, dest))
                tempLabel = append(tempLabel, knn.labels[i])
            }
        }
        //距離の近い上位k個のインデックスを獲得
        s := NewFloat64Slice(distList)
        sort.Sort(s)
        targetIndex := s.idx[:knn.k]

        //獲得したインデックスのラベルを獲得
        for _, ind := range targetIndex {
            nearLabels = append(nearLabels, tempLabel[ind])
        }

        //ラベルの出現頻度を獲得
        labelFreq := Counter(nearLabels)

        //最も出現回数の多いラベルが予測対象データの予測ラベル
        a := List{}
        for k, v := range labelFreq {
            e := Entry{k, v}
            a = append(a, e)
        }
        sort.Sort(a)
        predictedLabel = append(predictedLabel, a[0].name)
    }
    return predictedLabel

}

精度比較

ナイーブなkNNで予測を行うとかかった精度と時間は以下の通り。

  • accuracy :0.96
  • time:17.10s

上記のkmeans的要素を取り込んだkNNの精度と時間は以下の通り。

  • accuracy:0.96
  • time:2.46s

このデータの場合は正答率はそのままに所用時間が激減した。

まとめ

本や記事に載っているようなアルゴリズムをそのまま実装したものを手を加えて予測所要時間を短くすることができた。面白かった。
データとアルゴリズムによっては数手入れると精度や所要時間を改良することができる。モデルの目的によってはよく知られた既存のアルゴリズムよりも手を加えた方が、より、目的にあった物になることもある。
機械学習アルゴリズムの勉強にもなるし色々いじってみるのも楽しい。
とはいえ、気をつけなればいけないことはいくつかある。
まず、手の加え方やデータによっては、それによってそれ以前なら保たれていた精度が落ちることがある。手を加える時にはきちんと検証をするべき。また、機械学習のパッケージ、ライブラリは優秀だ。幾分工夫しても、目的の焦点が絞られていても、案外既存のライブラリ以上に目的を達することができる物を作るのは難しい。言語が違うので一概には言えないが、kNNの場合はsklearnだと今回と同じデータも時間を測るのが阿呆らしいくらい瞬殺だった。
本やアルゴリズムに関する記事だけではなく、ソースコードを読んでいってもう少し工夫を取り込んでいきたいところ。

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5