LoginSignup
7
7

More than 5 years have passed since last update.

Tensorflowのkmeansを使う

Last updated at Posted at 2018-04-01

はじめに

scikit の kmeans は euclidean distance はサポートするけれど、cosine distance はサポートしないようです。そこで、cosine distanceをサポートするものを探したところ、tensorflow に含まれる kmeans で可能でした。
tensorflow のkmeans を試そうとしたところ、いくつかつまずいた点があったので、まとめてみました。
なお、この記事は、2018年4月1日に書いていますが、その後に問題点が解決されている可能性もあります。

環境

以下の環境を使っています。
Cent OS: 7.4.1708
Python: 3.6.4
tensorflow: r1.7

kmeans は、tf.contrib.factorization.KMeansClustering というクラスを使いました。

tensorflow で kmeans を使うまでに解決した点

Illegal instruction 問題

これは、kmeans ではなくて、tensorflow r1.6, 1.7 の問題です。

import tensorflow as tf

上記のコードを実行すると、illegal Instruction で停止します。
以下のリンクに記述されている問題のようです。
https://github.com/tensorflow/tensorflow/issues/17411
もっとも、簡単なワークアラウンドは、r1.5 を使用することのようです。
r1.6 から Intel の AVX 命令を使うようになり、それをサポートしていない CPU だと illegal instruction になるようです。
多分、コンパイラの問題なので、自分の環境でビルドすれば大丈夫かと思い、ビルドしなおしたところ、illegal instruction は出なくなりました。ビルドは以下のリンクを参考に行いました。
Installing TensorFlow from Sources

Dataset と input_fn について

KMeansClustering Class は Estimator Class の子 class になります。
Premade Estimator の使用方法は、Premade Estimators for ML Beginners に記述されています。
これに従うと、Estimator に供給するデータは、tf.data.Dataset class のオブジェクトとして渡すように説明されています。
一方で、Estimator の method に Dataset を渡す場合は、そのための関数 input_fn を定義して、それを method に渡す必要があります。直接、Dataset を渡すわけでないので、注意が必要です。

また、r1.7 の KMeansClustering のドキュメントには、使用例が書いてあって、そこでは、Dataset を使わない方法でデータを供給していました。試したところ、どちらの方法であっても、動作しました。

Dataset class のほうがよりハイレベルで、面倒な処理を隠蔽してくれるようなので、動作するのであれば、Dataset class を使ったほうがいいと思います。

以下のデータを読み込む input_fn を考えてみます。9x2 の行列です。

import numpy as np
data_x = np.array([100.0, 110.0, 120.0, 150.0, 155.0, 150.0, 178.0, 180.0, 900.0],
                  dtype = 'float32')
data_y = np.array([20.0, 25.0, 30.0, 48.0, 45.0, 50.0, 78.0, 75.0, 180.0],
                  dtype = 'float32')
raw_data = np.c_[data_x, data_y]

Dataset を使った場合のコード

tf.data.Dataset.from_tensor_slices() を使って、行列を Dataset に変換します。ここでの注意点は、Dataset ではなく、batch() という method を call して、データを Batch で取り出せるオブジェクトを使うことです。batch_size で指定した個数のデータを一回の処理で取り出すことができます。
また、input_fn の戻り値としては、Dataset ではなくて、batch でデータを取り出すための method を返します。

import tensorflow as tf
def input_fn():
    dataset = tf.data.Dataset.from_tensor_slices(raw_data).batch(batch_size=10)
    return dataset.make_one_shot_iterator().get_next()

実際には、Estimator の中では、以下のような操作をして、データを取り出すと推測されます。

sess = tf.Session()
get_next = dataset.make_one_shot_iterator().get_next()
data = sess.run(get_next)

KMeansClustering の method である train, score などが入力データを得るために、input_fn を引数として必要とします。ここで、必要となるのは、input_fn という関数のポインタです。そのため、kmeans.train(input_fn()) ではなくて、kmeans.train(input_fn) のように引数を指定します。
この例では、raw_data などは関数内に直接書き込んでいるので、input_fn 自身は引数を必要としませんが、input_fn(raw_data) のように、入力データなどを引数で渡したい場合があるかと思います。その場合、kmeans.train(input_fn(raw_data)) などとしてしまうと、関数ではなくて、Iterator が返ってしまうので注意が必要です。そういうケースでは、input_fn を関数ポインタを返すように書き換えるか、以下のように lambda で包んで、関数ポインタとなるようにします。

kmeans.train(lambda:input_fn(raw_data))

Dataset class には、batch 以外にも、 shuffle, repeat といったような method もあり、機械学習向けにデータ加工を行うことができます。

tf.convert_to_tensor を使った場合のコード

tf.convert_to_tensor を使うと、データをtensorflow の tensor 形式にしてくれます。KMeansClustering のドキュメント の Example で使われている方法です。

import tensorflow as tf
def input_fn():
    tensor =  tf.convert_to_tensor(raw_data, dtype=tf.float32)
    return tf.train.limit_epochs(tensor, num_epochs=1)

この場合も、tensor そのものを戻り値にするのではなく、tensor を取り出す method を指定します。tf.train.limit_epochs() は、データを取り出せる回数を num_epochs で指定します。この例では、一回取り出すことができれば十分なので、num_epochs は1にしています。一回の操作で tensor 全体が取り出されます。

tf.contrib.factorization.KMeansClustering とtf.contrib.learn.KMeansClustering

同じクラス名が二つあるので、紛らわしいです。Web 検索で見つかる記事によっては、tf.contrib.learn.KMeansClustering について書かれているものがあるので、それを見ながら tf.contrib.factorization.KMeansClustering を使おうとする混乱します。tf.contrib.learn.KMeansClustering は scikit-learn に含まれる kmeans と互換のようです。ただ、r1.7 のドキュメントによるとこのクラスは廃止になったとあります。

 テストコード

大部分は、tf.contrib.factorization.KMeansClustering の Example をそのまま使っています。

違いは以下のような部分です。
* Dataset でデータを Estimator に渡しています。
* COSINE_DISTANCE を使っています。
* use_mini_batch=False にしています。
use_mini_batch を True にすると、少しずつデータを読み込んで incremental に処理をするようなことがドキュメントに書いてありましたが、このテストではデータ数もすくないので、mini_batch は使っていません。

コードの動作としては、以下のようになっています。
* tf.contrib.factorization.KMeansClustering() で Estimator kmeans を初期化しています。
* kmeans.train() で学習をします。実際には、kmeans 法で、各 cluster の中心を計算します。
* kmeans.train() を N回実行することで、cluster の中心を再計算し、より誤差の少ない中心を求めています。kmeans.score() で各入力データと、最近傍の cluster 中心との距離の二乗誤差の合計を得ることができます。
* kmeans.train() には、steps, max_steps という再計算を複数行わせるための引数もあります。それを使う場合には、dataset の方で、batch_sizeを正しいデータ数に合わせると共に、step 数に合わせて十分な回数データを供給できるようします。具体的には、repeat method を使用して、step 数に合わせて、同じデータを繰り返すようにします。
* kmeans.predict_cluster_index() で入力データに対する cluster index を得ることができます。

#!/bin/env python
import tensorflow as tf
import numpy as np

# Test Data
data_x = np.array([100.0, 110.0, 120.0, 150.0, 155.0, 150.0, 178.0, 180.0, 900.0],
                  dtype = 'float32')
data_y = np.array([20.0, 25.0, 30.0, 48.0, 45.0, 50.0, 78.0, 75.0, 180.0],
                  dtype = 'float32')
raw_data = np.c_[data_x, data_y]

# Input function for Estimator
def input_fn():
    dataset = tf.data.Dataset.from_tensor_slices(raw_data).batch(batch_size=10)
    return dataset.make_one_shot_iterator().get_next()

#### main ####
num_clusters = 5
kmeans = tf.contrib.factorization.KMeansClustering(
    num_clusters=num_clusters,
    distance_metric=tf.contrib.factorization.KMeansClustering.COSINE_DISTANCE,
    use_mini_batch=False)

# train
num_iterations = 5
previous_centers = None
for i in range(num_iterations):
    print('iteration: ', i)
    kmeans.train(input_fn)
    cluster_centers = kmeans.cluster_centers()
    if previous_centers is not None:
        print('  delta:', cluster_centers - previous_centers)
    previous_centers = cluster_centers
    print('  score:', kmeans.score(input_fn))
print('Final cluster centers:', cluster_centers)
print()

# map the input points to their clusters
cluster_indices = list(kmeans.predict_cluster_index(input_fn))
for i, point in enumerate(raw_data):
  cluster_index = cluster_indices[i]
  center = cluster_centers[cluster_index]
  print('point:', point, 'is in cluster', cluster_index, 'centered at',center)
7
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
7