Help us understand the problem. What is going on with this article?

TensorFlow用のインターフェースTFLearnが新しくなっていたので試してみた

More than 3 years have passed since last update.

TensorFlow0.9からTFLearnが独立して、使いやすくなっていそうだったので使ってみた

インストール

TensorFlow0.9以上が必要

$ pip install tflearn

上でもインストールできるけど、現時点(0.2.1)では公式のチュートリアルも実行できなかったので以下で最新版(0.2.2)をインストール

$ pip install git+https://github.com/tflearn/tflearn.git

チュートリアル

タイタニック号乗客の生存確率をMulti Layer Perceptronで解く問題が公式にある

TFLearn - Quick Start

モデルの定義と学習部分を抜き出したのが以下

titanic.py
# Build neural network
net = tflearn.input_data(shape=[None, 6])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)

# Define model
model = tflearn.DNN(net)
# Start training (apply gradient descent algorithm)
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)

すっきりしていて書きやすい

VGG Net

vgg.py
# -*- coding: utf-8 -*-

from __future__ import (
    absolute_import,
    division,
    print_function
)

from six.moves import range

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression

net = input_data(shape=[None, width, height, channel])

for i in range(2):
    net = conv_2d(net, 64, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)

for i in range(2):
    net = conv_2d(net, 128, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)

for i in range(3):
    net = conv_2d(net, 256, 3, activation='relu')
net = max_pool_2d(net, 2, strides=2)

for i in range(2):
    for j in range(3):
        net = conv_2d(net, 512, 3, activation='relu')
    net = max_pool_2d(net, 2, strides=2)

for i in range(2):
    net = fully_connected(net, 4096, activation='relu')
    net = dropout(net, 0.5)

net = fully_connected(net, n_classes, activation='softmax')

net = regression(net, optimizer='adam', loss='categorical_crossentropy')

model = tflearn.DNN(net, checkpoint_path='model_vgg', max_checkpoints=1, tensorboard_verbose=0)

model.fit(X, y, n_epoch=500, shuffle=True, show_metric=True, batch_size=32)

元のサンプルをfor文を使って書き直したけどかえって構造がわかりにくい気もするので、そこはfor文を使わずにコピペするなどすればいい
slimのrepeatやarg_scopeが使えるようになると、もっとすっきりすると思う
slimは別に開発されているけど、TFLearnに統合されそうな気がする

サンプル

公式ドキュメントにNetwork in Network, VGG, Residual Network, LSTM, Seq2seq, Q-learning(パックマン)と色々なサンプルがある

TFLearn Examples

TFLearnの利点

  • TensorFlow0.8くらいときに比べて、非常に書きやすくなっている
  • Convolution 2D Transpose, Residual Block / Residual Bottleneckなどの様々なレイヤーが標準装備
  • 生のTensorFlowとも親和性高い

感想

まだドキュメントにも色々とToDoが残ってるし、特に他にない機能を使いたいのでなければKerasなどを使っていてもいいかもしれない
TensorFlowの周辺は開発が早くて、気づいたら色々なものが出てきそう

shngt
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away