Help us understand the problem. What is going on with this article?

生のTensorFlowとtf.contrib.learnとKerasを比較してみよう

More than 3 years have passed since last update.

これはなに?

DevFest Tokyo 2016での 発表で出したサンプルコードの全体を見たいという要望を頂いたので晒しておきます。

やりたかったこと

生のTensorFlowとTensorFlowの高レベルAPI版tf.contrib.learnとTensorFlowをバックエンドにしてDSLっぽくネットワークを記述できるKerasで、同じデータ、同じ手法で揃えてどう違うか、と横串で見れるものがなかったので横並びで見えるようにしたかった。
と、いうのも、それぞれのチュートリアルを見ていくと、微妙にやっていることが違うため、混乱してしまう、というか私は混乱してしまいました。

自分が調べた範囲では同じような条件で比較しているものがなかったので、混乱する方が少しでも減るように記録を残しておこうかなと思った次第です。とりあえずtf.contrib.learnのチュートリアルのものを基準に他のものを揃える感じでいきます。

揃えた条件

  • iris datasetという統計分析とか機械学習とかでよく例として出て来る花(アヤメ)のデータを使う
    • tf.contrib.learn.datasets.base.load_iris()から拝借
  • 中間層3で、各層は10, 20, 30
  • 活性化関数はReLu
  • 出力層はsoftmax
  • 学習データ、テストデータの分割はsklearn.cross_validation()の力を借りる

揃えていないもの

  • オプティマイザ
  • ネットワークの初期値の設定とか

後はほかも色々揃ってない気もするけど大体揃ってる感じはある

やっていないこと

  • slimの対応

要望はありそうでしたが時間の都合でたどり着いていないです。

比較

それでは早速比較です

生TensorFlow版

  • 全ての工程を自前でやっているので、ちゃんとやってることを理解するには良さそう
  • mnistのチュートリアルのコードはそれ用のヘルパー関数がありすぎて、TensorFlow自体が提供してくれているものとの見分けがつかず、いざやろうとするとチュートリアル専用のものが多く結構悲しい
    • なのでそれっぽいヘルパー関数を適当に用意してみた
import tensorflow as tf
import numpy as np
from sklearn import cross_validation

# 1-hotベクトル生成用の関数
def one_hot_labels(labels):
    return np.array([
        np.where(labels == 0, [1], [0]),
        np.where(labels == 1, [1], [0]),
        np.where(labels == 2, [1], [0])
    ]).T

# 指定されたバッチサイズでデータをランダムに取得
def next_batch(data, label, batch_size):
    perm = np.arange(data.shape[0])
    np.random.shuffle(perm)
    return data[perm][:batch_size], label[perm][:batch_size]

# 学習データの準備
iris = tf.contrib.learn.datasets.base.load_iris()

train_x, test_x, train_y, test_y = cross_validation.train_test_split(
    iris.data, iris.target, test_size=0.2
)

# 入力層
x = tf.placeholder(tf.float32, [None, 4], name='input')

# 第1層
W1 = tf.Variable(tf.truncated_normal([4, 10], stddev=0.5, name='weight1'))
b1 = tf.Variable(tf.constant(0.0, shape=[10], name='bias1'))
h1 = tf.nn.relu(tf.matmul(x,W1) + b1)

# 第2層
W2 = tf.Variable(tf.truncated_normal([10, 20], stddev=0.5, name='weight2'))
b2 = tf.Variable(tf.constant(0.0, shape=[20], name='bias2'))
h2 = tf.nn.relu(tf.matmul(h1,W2) + b2)

# 第3層
W3 = tf.Variable(tf.truncated_normal([20, 10], stddev=0.5, name='weight3'))
b3 = tf.Variable(tf.constant(0.0, shape=[10], name='bias3'))
h3 = tf.nn.relu(tf.matmul(h2,W3) + b3)

# 出力層
W4 = tf.Variable(tf.truncated_normal([10, 3], stddev=0.5, name='weight4'))
b4 = tf.Variable(tf.constant(0.0, shape=[3], name='bias4'))
y = tf.nn.softmax(tf.matmul(h3,W4) + b4)

# 理想的な出力値
y_ = tf.placeholder(tf.float32, [None, 3], name='teacher_signal')

# 理想的な出力値との比較
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

# 学習処理と呼ばれているもの
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for i in range(2000):
        # 学習処理
        batch_size = 100
        batch_train_x, batch_train_y = next_batch(train_x, train_y, batch_size)
        sess.run(train_step, feed_dict={x: batch_train_x, y_: one_hot_labels(batch_train_y)})

    # 学習結果の評価
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: test_x, y_: one_hot_labels(test_y)}))

tf.contrib.learn版

  • 生のTensorFlowとくらべて拍子抜け感があるくらい楽!
  • APIを読むとoptimizerdropoutやその他色々設定できそうなので、思ったよりも色々できそう
import tensorflow as tf
from sklearn import cross_validation

# 学習データの準備
iris = tf.contrib.learn.datasets.base.load_iris()
train_x, test_x, train_y, test_y = cross_validation.train_test_split(
    iris.data, iris.target, test_size=0.2
)

# 全ての特徴は実数値ですよと教える
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# 3層のDNN
# 特に何も指定しないと活性化関数はReLUを選んでくれるっぽい
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="./iris_model")
# モデルのフィッティング
classifier.fit(x=train_x,
               y=train_y,
               steps=2000,
               batch_size=50)

# 精度の評価
print(classifier.evaluate(x=test_x, y=test_y)["accuracy"])

Keras版

  • もう全然違うものですね
  • が、実験的にネットワークを構築していく場合は、生TensorFlowよりもはるかに見通しが良さそう
import tensorflow as tf
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from sklearn import cross_validation

# 入力データの準備
iris = tf.contrib.learn.datasets.base.load_iris()
train_x, test_x, train_y, test_y = cross_validation.train_test_split(
    iris.data, iris.target, test_size=0.2
)

# モデルの定義
model = Sequential()

# ネットワークの定義
model.add(Dense(input_dim=4, output_dim=10))
model.add(Activation('relu'))
model.add(Dense(input_dim=10, output_dim=20))
model.add(Activation('relu'))
model.add(Dense(input_dim=20, output_dim=10))
model.add(Activation('relu'))
model.add(Dense(output_dim=3))
model.add(Activation('softmax'))

# ネットワークのコンパイル
model.compile(loss = 'sparse_categorical_crossentropy',
              optimizer = 'sgd',
              metrics = ['accuracy'])

# 学習処理
model.fit(train_x, train_y, nb_epoch = 2000, batch_size = 100)

# 学習結果の評価
loss, metrics = model.evaluate(test_x, test_y)

勝手な考察

パット見た感じ、以下のような感じでしょうか。
- 実験的に取り組むならKeras
- やることが決まっていて、提供されているAPIでやりたい事が実現できるならtf.learn
- Kerasで作ったものがtf.learnでは再現できない時、あるいは、すでに実装したいものがイメージできている場合は生のTensorFlow(でもデバッグが大変かも)

やらない可能性が大いにあるTODO

  • もっと深掘り
  • slimも交えた比較
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした