29
31

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Tensorflowの便利ライブラリTF-Slim

Last updated at Posted at 2016-07-30

Tensorflowの便利ライブラリTF-Slim

以前、Tensorflowのラッパーライブラリについて書いた。

Keras(TensorFlowバックエンド)でMNIST
skflowでMNIST

Kerasは使いやすそうなもののDeconvolution(Transposed Convolution)など用意されていないものがあったので、(よく見てみたらUpsampling2DとConvolution2Dを使えばできる気がする)
skflowや生のTensorflowを勉強しようかとTutorialを眺めていたところ、隠された便利ライブラリを発見した。

TensorFlow-Slim

Githubのリポジトリだと0.9から追加されたみたい。

contribの中にはあまりアナウンスされていないけど良さげなものが含まれていたりするので、一度見ておくといいと思う。正式サポートしないし、将来的にどうなるかわからないよとは書いてあるけど。

tensorflow/tensorflow/contrib/

以下の内容は、今後大幅に変わる可能性もある。

TF-Slimのインポート

import tensorflow as tf
from tensorflow.contrib import slim

ウェイトの初期化

weights = slim.variables.variable('weights',
                             shape=[10, 10, 3 , 3],
                             initializer=tf.truncated_normal_initializer(stddev=0.1),
                             regularizer=slim.l2_regularizer(0.05),
                             device='/CPU:0')

これで動くと書いてあるが、variableというメソッドは現時点で実装されていない感じがするので、inceptionの中に入っているものを持ってきてインポートしないといけないかも。

models/inception/inception/

層の定義

(conv+pool)5, fc3のレイヤーは以下のように書ける
READMEに書いてある通り、VGG16でした。
conv2
pool
conv
2
pool
conv3
pool
conv
3
pool
conv3
pool
fc
3

with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
  net = slim.ops.repeat_op(2, inputs, slim.ops.conv2d, 64, [3, 3], scope='conv1')
  net = slim.ops.max_pool(net, [2, 2], scope='pool1')
  net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 128, [3, 3], scope='conv2')
  net = slim.ops.max_pool(net, [2, 2], scope='pool2')
  net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
  net = slim.ops.max_pool(net, [2, 2], scope='pool3')
  net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv4')
  net = slim.ops.max_pool(net, [2, 2], scope='pool4')
  net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv5')
  net = slim.ops.max_pool(net, [2, 2], scope='pool5')
  net = slim.ops.flatten(net, scope='flatten5')
  net = slim.ops.fc(net, 4096, scope='fc6')
  net = slim.ops.dropout(net, 0.5, scope='dropout6')
  net = slim.ops.fc(net, 4096, scope='fc7')
  net = slim.ops.dropout(net, 0.5, scope='dropout7')
  net = slim.ops.fc(net, 1000, activation=None, scope='fc8')
return net

conv*3+poolをこんな感じで短縮して書くこともできる

net = ...
for i in range(3):
  net = slim.ops.conv2d(net, 256, [3, 3], scope='conv3_' % (i+1))
net = slim.ops.max_pool(net, [2, 2], scope='pool3')

さらに、slimに用意されたrepeatメソッドを使うと

net = slim.ops.repeat_op(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.ops.max_pool(net, [2, 2], scope='pool2')

これでちゃんとスコープを'conv3/conv3_1', 'conv3/conv3_2', 'conv3/conv3_3'と調整してくれるらしい。

fc*3も以下のものを

x = slim.ops.fc(x, 32, scope='fc/fc_1')
x = slim.ops.fc(x, 64, scope='fc/fc_2')
x = slim.ops.fc(x, 128, scope='fc/fc_3')

1行で書ける

slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')

ということは当然convも

slim.stack(x, slim.ops.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')

でOKとcontribのREADMEに書いてあるが、実装はない。(InceptionのREADMEには書いていない)

スコープ

例えばこんなconv*3の層があるとして

padding = 'SAME'
initializer = tf.truncated_normal_initializer(stddev=0.01)
regularizer = slim.losses.l2_regularizer(0.0005)
net = slim.ops.conv2d(inputs, 64, [11, 11], 4,
                      padding=padding,
                      weights_initializer=initializer,
                      weights_regularizer=regularizer,
                      scope='conv1')
net = slim.ops.conv2d(net, 128, [11, 11],
                      padding='VALID',
                      weights_initializer=initializer,
                      weights_regularizer=regularizer,
                      scope='conv2')
net = slim.ops.conv2d(net, 256, [11, 11],
                      padding=padding,
                      weights_initializer=initializer,
                      weights_regularizer=regularizer,
                      scope='conv3')

slimに用意されたscopeを使うと引数が違う部分だけを記述し、残りを省略できる

with slim.arg_scope([slim.ops.conv2d], padding='SAME',
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01)
                    weights_regularizer=slim.losses.l2_regularizer(0.0005)):
  net = slim.ops.conv2d(inputs, 64, [11, 11], scope='conv1')
  net = slim.ops.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
  net = slim.ops.conv2d(net, 256, [11, 11], scope='conv3')

さらに、scopeを重ねて

with slim.arg_scope([slim.ops.conv2d, slim.ops.fc],
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                    weights_regularizer=slim.losses.l2_regularizer(0.0005)):
with arg_scope([slim.ops.conv2d], stride=1, padding='SAME'):
  net = slim.ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
  net = slim.ops.conv2d(net, 256, [5, 5],
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
                    scope='conv2')
  net = slim.ops.fc(net, 1000, activation_fn=None, scope='fc')

convとfcに共通するものを定義した後、convだけに適用するものを定義できる。

損失関数

これでOK

loss = slim.losses.cross_entropy_loss(predictions, labels)

訓練

slim.learningはInceptionには見当たらず、contrib内のslimには存在。

g = tf.Graph()

# モデルと損失関数を定義
# ...

total_loss = tf.get_collection(slim.losses.LOSSES_COLLECTION)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

train_op = slim.learning.create_train_op(total_loss, optimizer)
logdir = './stored_log/'

slim.learning.train(
    train_op,
    logdir,
    number_of_steps=1000,
    save_summaries_secs=300,
    save_interval_secs=600)

感想

実際に使えるようになると、かなり便利になるんじゃないかと感じる。
v0.10では普通に使えるようになっているとかなり嬉しい。

29
31
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
31

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?