LoginSignup
19
16

More than 5 years have passed since last update.

TensorFlow用のインターフェースTFLearnが新しくなっていたので試してみた

Posted at

TensorFlow0.9からTFLearnが独立して、使いやすくなっていそうだったので使ってみた

インストール

TensorFlow0.9以上が必要

$ pip install tflearn

上でもインストールできるけど、現時点(0.2.1)では公式のチュートリアルも実行できなかったので以下で最新版(0.2.2)をインストール

$ pip install git+https://github.com/tflearn/tflearn.git

チュートリアル

タイタニック号乗客の生存確率をMulti Layer Perceptronで解く問題が公式にある

TFLearn - Quick Start

モデルの定義と学習部分を抜き出したのが以下

titanic.py
# Build neural network
net = tflearn.input_data(shape=[None, 6])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)

# Define model
model = tflearn.DNN(net)
# Start training (apply gradient descent algorithm)
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)

すっきりしていて書きやすい

VGG Net

vgg.py
# -*- coding: utf-8 -*-

from __future__ import (
    absolute_import,
    division,
    print_function
)

from six.moves import range

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression

net = input_data(shape=[None, width, height, channel])

for i in range(2):
    net = conv_2d(net, 64, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)

for i in range(2):
    net = conv_2d(net, 128, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)

for i in range(3):
    net = conv_2d(net, 256, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)

for i in range(2):
    for j in range(3):
        net = conv_2d(net, 512, 3, activation='relu')
    net = max_pool_2d(net, 2, strides=2)

for i in range(2):
    net = fully_connected(net, 4096, activation='relu')
    net = dropout(net, 0.5)

net = fully_connected(net, n_classes, activation='softmax')

net = regression(net, optimizer='adam', loss='categorical_crossentropy')

model = tflearn.DNN(net, checkpoint_path='model_vgg', max_checkpoints=1, tensorboard_verbose=0)

model.fit(X, y, n_epoch=500, shuffle=True, show_metric=True, batch_size=32)

元のサンプルをfor文を使って書き直したけどかえって構造がわかりにくい気もするので、そこはfor文を使わずにコピペするなどすればいい
slimのrepeatやarg_scopeが使えるようになると、もっとすっきりすると思う
slimは別に開発されているけど、TFLearnに統合されそうな気がする

サンプル

公式ドキュメントにNetwork in Network, VGG, Residual Network, LSTM, Seq2seq, Q-learning(パックマン)と色々なサンプルがある

TFLearn Examples

TFLearnの利点

  • TensorFlow0.8くらいときに比べて、非常に書きやすくなっている
  • Convolution 2D Transpose, Residual Block / Residual Bottleneckなどの様々なレイヤーが標準装備
  • 生のTensorFlowとも親和性高い

感想

まだドキュメントにも色々とToDoが残ってるし、特に他にない機能を使いたいのでなければKerasなどを使っていてもいいかもしれない
TensorFlowの周辺は開発が早くて、気づいたら色々なものが出てきそう

19
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
16