ソースは gist にもある。
目的
- TensorFlowは線形代数APIを多く備えていてニューラルネットワーク専用フレームワークではないということを確認する。
- TensorFlow初心者であるためお手本なしでグラフを組む練習
作ったもの
- 複数回試行などを実装しないシンプルなk-means
- 重心の異なるサンプルを生成してクラスタリングする練習コード
- メジャーなライブラリの実装のような工夫をしていないので綺麗にクラスタが分かれることはあまりない
わかったこと、困ったこと
- tf.nn 系のチュートリアルだとVariableの更新は自動で行われるが tf.nn から離れると自前で行わないといけない(今回だと重心の更新)
-
Tensor.assin()
で更新できるとドキュメントにはあったがうまく動かなかった。今回はイテレーションごとにplaceholder
に値を渡した。 - map関数的なオペレーションがないため余計なメモリを食うコードになった。
distance_map
あたり。「n次元目の要素ごとにfを適用する」というようなことを綺麗に書くAPIが増えてほしい。 - 要素ごとに何かする関係で
reshape
を多用することになった。行列演算では当たり前なのかもしれないがパズルのようで面白かった。
質問、指摘、意見、アドバイスはコメント欄からよろしくお願いします。
import tensorflow as tf
import numpy as np
import sys
# simple k-means clustering with tensorflow
# this is my practice code for tensorflow graph coding
# parameters
K_SIZE = 5
SAMPLE_SIZE = 10000
SAMPLES_COV = [[5,0],[0,5]]
SAMPLES_MEAN_PARAM = 100
# create samples and initial centers
sample_list = []
for i in range(K_SIZE):
random_center = (SAMPLES_MEAN_PARAM * np.random.random_sample(2) - SAMPLES_MEAN_PARAM / 2).astype(np.float32)
sample_list.append(np.random.multivariate_normal(random_center, SAMPLES_COV, [SAMPLE_SIZE, 2]).astype(np.float32))
samples_data = np.reshape(sample_list,[-1,2])
np.random.shuffle(samples_data)
center_list = samples_data[0:5]
centers_data = np.reshape(center_list,[-1,2])
# make tensorflow graph
samples = tf.Variable(samples_data)
centers = tf.placeholder(tf.float32)
centers_map = tf.tile(centers, [samples_data.shape[0],1])
centers_map = tf.reshape(centers_map,[samples_data.shape[0] * K_SIZE, 2])
samples_map = tf.tile(samples, [1,K_SIZE])
samples_map = tf.reshape(samples_map,[samples_data.shape[0] * K_SIZE, 2])
distance_map = tf.sub(centers_map, samples_map)
distance_map = tf.pow(distance_map, [2])
distance_map = tf.reduce_sum(distance_map, 1)
distance_map = tf.reshape(distance_map, [-1, K_SIZE])
labels = tf.to_int32(tf.argmin(distance_map, 1))
# centers update iteration
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)
updated_samples_list = tf.dynamic_partition(samples, labels, K_SIZE)
updated_centers = tf.concat(0,[tf.reduce_mean(t,0) for t in updated_samples_list])
updated_centers = tf.reshape(updated_centers, [-1,2])
for i in range(100):
centers_data = sess.run(updated_centers, feed_dict={centers:centers_data})
labeled_samples = sess.run(updated_samples_list, feed_dict={centers: centers_data})
# output clustered samples with gnuplot points format
index = 0
command = 'plot '
for s in labeled_samples:
if not len(s):
continue
command = command + ' "out.txt" using 1:2 index {0} with points,'.format(index)
index = index + 1
for p in s:
print('{0} {1}'.format(p[0],p[1]))
print('')
print('')
# output plot command to stderr
print(command[0:-1], file=sys.stderr)
sess.close()