TensorFlow0.9からTFLearnが独立して、使いやすくなっていそうだったので使ってみた
インストール
TensorFlow0.9以上が必要
$ pip install tflearn
上でもインストールできるけど、現時点(0.2.1)では公式のチュートリアルも実行できなかったので以下で最新版(0.2.2)をインストール
$ pip install git+https://github.com/tflearn/tflearn.git
チュートリアル
タイタニック号乗客の生存確率をMulti Layer Perceptronで解く問題が公式にある
モデルの定義と学習部分を抜き出したのが以下
titanic.py
# Build neural network
net = tflearn.input_data(shape=[None, 6])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)
# Define model
model = tflearn.DNN(net)
# Start training (apply gradient descent algorithm)
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
すっきりしていて書きやすい
VGG Net
vgg.py
# -*- coding: utf-8 -*-
from __future__ import (
absolute_import,
division,
print_function
)
from six.moves import range
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
net = input_data(shape=[None, width, height, channel])
for i in range(2):
net = conv_2d(net, 64, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)
for i in range(2):
net = conv_2d(net, 128, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)
for i in range(3):
net = conv_2d(net, 256, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)
for i in range(2):
for j in range(3):
net = conv_2d(net, 512, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)
for i in range(2):
net = fully_connected(net, 4096, activation='relu')
net = dropout(net, 0.5)
net = fully_connected(net, n_classes, activation='softmax')
net = regression(net, optimizer='adam', loss='categorical_crossentropy')
model = tflearn.DNN(net, checkpoint_path='model_vgg', max_checkpoints=1, tensorboard_verbose=0)
model.fit(X, y, n_epoch=500, shuffle=True, show_metric=True, batch_size=32)
元のサンプルをfor文を使って書き直したけどかえって構造がわかりにくい気もするので、そこはfor文を使わずにコピペするなどすればいい
slimのrepeatやarg_scopeが使えるようになると、もっとすっきりすると思う
slimは別に開発されているけど、TFLearnに統合されそうな気がする
サンプル
公式ドキュメントにNetwork in Network, VGG, Residual Network, LSTM, Seq2seq, Q-learning(パックマン)と色々なサンプルがある
TFLearnの利点
- TensorFlow0.8くらいときに比べて、非常に書きやすくなっている
- Convolution 2D Transpose, Residual Block / Residual Bottleneckなどの様々なレイヤーが標準装備
- 生のTensorFlowとも親和性高い
感想
まだドキュメントにも色々とToDoが残ってるし、特に他にない機能を使いたいのでなければKerasなどを使っていてもいいかもしれない
TensorFlowの周辺は開発が早くて、気づいたら色々なものが出てきそう