Slimとは
TF-Slim is a lightweight library for defining, training and evaluating complex models in TensorFlow. Components of tf-slim can be freely mixed with native tensorflow, as well as other frameworks, such as tf.contrib.learn.
(意訳)
TF-SlimはTensorFlowで複雑なモデルを定義・訓練・評価するための軽量ライブラリで、生のTensorFlowやTF-Learn(tf.contrib.learn)などの他のフレームワークと自由に混ぜて使える。
GitHubのリポジトリで開発が進んでいる。
r0.10のブランチからTensorFlowをビルドすれば使えそうだけど、0.9.0でも一部機能は実装されていたので無理やり使ってみた。
題材はCIFAR10
インポート
from tensorflow.contrib import slim
データ読み込み
0.10だとslim.dataに色々なメソッドが実装されていそうだったけど、0.9ではslim.dataが存在しないので諦める。
モデル定義
conv*2, pool*1, norm*1のセットを2回→fc*3
# Network body: two stages of (conv x2 -> max-pool -> local response norm),
# then flatten, a stack of fully connected layers, and the output layer.
layers = slim.layers
# (kernel size, conv scope, pool scope, norm op name) for each stage.
stages = (
    ([3, 3], 'conv1', 'pool1', 'norm1'),
    ([5, 5], 'conv2', 'pool2', 'norm2'),
)
with slim.arg_scope([layers.conv2d, layers.fully_connected],
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                    weights_regularizer=slim.regularizers.l2_regularizer(0.0005)):
    with slim.arg_scope([layers.max_pool2d], padding='SAME'):
        net = inputs
        for kernel, conv_scope, pool_scope, norm_name in stages:
            net = layers.repeat(net, 2, layers.conv2d, 64, kernel, scope=conv_scope)
            net = layers.max_pool2d(net, [3, 3], scope=pool_scope)
            net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name=norm_name)
        net = layers.flatten(net, scope='flatten')
        net = layers.stack(net, layers.fully_connected, [384, 192], scope='fc1')
        # The last layer overrides the scoped ReLU with no activation.
        net = layers.fully_connected(net, NUM_CLASSES, activation_fn=None, scope='fc2')
repeatは0.9に含まれていないけど、これだけはないと不便なので0.10からコピー。
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
def repeat(inputs, repetitions, layer, *args, **kwargs):
    """Apply `layer` to `inputs` `repetitions` times, chaining the outputs.

    Copied from TensorFlow 0.10 because 0.9 does not ship this helper.

    Args:
      inputs: value converted to a tensor and fed to the first call.
      repetitions: number of times `layer` is applied.
      layer: callable invoked as `layer(outputs, *args, **kwargs)`.
      *args: extra positional arguments forwarded to every call.
      **kwargs: extra keyword arguments forwarded to every call; an optional
        `scope` entry is used as the name prefix for each repetition.

    Returns:
      The output of the final `layer` call.
    """
    base_scope = kwargs.pop('scope', None)
    with variable_scope.variable_op_scope([inputs], base_scope, 'Repeat'):
        net = ops.convert_to_tensor(inputs)
        if base_scope is None:
            # No explicit scope: name the repetitions after the layer callable.
            if hasattr(layer, '__name__'):
                base_scope = layer.__name__
            elif hasattr(getattr(layer, 'func', None), '__name__'):
                # In case layer is a functools.partial.
                base_scope = layer.func.__name__
            else:
                base_scope = 'repeat'
        # Each repetition gets its own scope: <base>_1, <base>_2, ...
        for step in range(1, repetitions + 1):
            kwargs['scope'] = base_scope + '_' + str(step)
            net = layer(net, *args, **kwargs)
        return net
# TensorFlow 0.9 has no slim.layers.repeat, so install the backport there;
# leave the library's own implementation alone when it already exists.
if not hasattr(slim.layers, 'repeat'):
    slim.layers.repeat = repeat
損失関数
# Softmax cross-entropy between the model output (`predictions`) and `labels`.
loss = slim.losses.softmax_cross_entropy(predictions, labels)
最適化
# Plain gradient descent using the LEARNING_RATE constant.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
学習
# Build the training op from the loss and optimizer, then run training for
# 1000 steps with FLAGS.log_dir as the output directory.
# NOTE(review): save_summaries_secs / save_interval_secs presumably set how
# often (in seconds) summaries and checkpoints are written -- confirm against
# the slim.learning.train documentation for the TensorFlow version in use.
train_op = slim.learning.create_train_op(loss, optimizer)
result = slim.learning.train(train_op,
FLAGS.log_dir,
number_of_steps=1000,
save_summaries_secs=300,
save_interval_secs=600)
評価
slim.evaluationに用意されているが、うまく使えなかったので割愛。
感想
READMEに書いてあることも信用できないし、0.9だと未実装なものや0.10ではメソッド名が変わっているものもあってコードを直接見ることになるので、素直に0.10を待ったほうがいいかもしれない。
ウェイト・バイアスのリストアとモデル評価のところがうまくいかなかったので、誰かできたら教えて欲しい……
モデルの定義に限って言えば、Slimを利用すると簡潔に書けていい。変数のリストア方法はわからなかったけど……
コード
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import slim
import os, sys, tarfile
from six.moves import urllib, range
from tensorflow.models.image.cifar10 import cifar10_input
# Command-line flags for this script, read through FLAGS.
# NOTE(review): distorted_inputs() also reads FLAGS.batch_size, which is not
# defined in this file -- presumably it is registered as a side effect of
# importing the cifar10 package. Confirm before defining it here, since a
# duplicate flag definition could conflict.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('tmp_data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_string('log_dir', '/tmp/cifar10_train',
"""Directory where to write event logs """
"""and checkpoint.""")
# Archive fetched by maybe_download_and_extract().
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
# Number of label classes, as declared by cifar10_input.
NUM_CLASSES = cifar10_input.NUM_CLASSES
# Learning rate handed to GradientDescentOptimizer in main().
LEARNING_RATE = 0.1
def main(argv=None):  # pylint: disable=unused-argument
    """Download CIFAR-10, train the model for 1000 steps and print the loss.

    The log directory named by FLAGS.log_dir is deleted and recreated before
    training, so every run starts from scratch.

    Args:
      argv: unused; accepted because tf.app.run() passes it.
    """
    maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)

    with tf.Graph().as_default():
        images, labels = distorted_inputs()
        labels = slim.layers.one_hot_encoding(labels, NUM_CLASSES)

        predictions = my_model(images)
        cross_entropy = slim.losses.softmax_cross_entropy(predictions, labels)

        # my_model() configures an L2 weights_regularizer. Those penalty terms
        # are only registered in the REGULARIZATION_LOSSES collection; they
        # influence training only when added to the loss being optimized.
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        if regularization_losses:
            loss = tf.add_n([cross_entropy] + regularization_losses,
                            name='total_loss')
        else:
            loss = cross_entropy
        tf.scalar_summary('loss', loss)

        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=LEARNING_RATE)

        train_op = slim.learning.create_train_op(loss, optimizer)
        result = slim.learning.train(train_op,
                                     FLAGS.log_dir,
                                     number_of_steps=1000,
                                     save_summaries_secs=300,
                                     save_interval_secs=600)
        # The value returned by slim.learning.train is reported as the loss.
        print('loss: %f' % result)
def my_model(inputs):
    """Build the CIFAR-10 network on top of `inputs` and return its output.

    Layout: two stages of (conv x2 -> max-pool -> local response norm), a
    flatten, a stack of fully_connected layers built from [384, 192], and a
    final NUM_CLASSES-unit fully_connected layer with no activation.

    Args:
      inputs: tensor fed to the first convolution.

    Returns:
      The output of the final fully_connected layer ('fc2').
    """
    layers = slim.layers
    # (kernel size, conv scope, pool scope, norm op name) for each stage.
    stages = (
        ([3, 3], 'conv1', 'pool1', 'norm1'),
        ([5, 5], 'conv2', 'pool2', 'norm2'),
    )
    with slim.arg_scope([layers.conv2d, layers.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.regularizers.l2_regularizer(0.0005)):
        with slim.arg_scope([layers.max_pool2d], padding='SAME'):
            net = inputs
            for kernel, conv_scope, pool_scope, norm_name in stages:
                net = layers.repeat(net, 2, layers.conv2d, 64, kernel, scope=conv_scope)
                net = layers.max_pool2d(net, [3, 3], scope=pool_scope)
                net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name=norm_name)
            net = layers.flatten(net, scope='flatten')
            net = layers.stack(net, layers.fully_connected, [384, 192], scope='fc1')
            # The last layer overrides the scoped ReLU with no activation.
            net = layers.fully_connected(net, NUM_CLASSES, activation_fn=None, scope='fc2')
    return net
def distorted_inputs():
    """Return the (images, labels) pair from cifar10_input.distorted_inputs.

    The data is read from the 'cifar-10-batches-bin' directory under
    FLAGS.tmp_data_dir, with FLAGS.batch_size as the batch size.
    NOTE(review): batch_size is not defined in this file -- presumably it is
    registered by the imported cifar10 package; confirm.

    Raises:
      ValueError: if FLAGS.tmp_data_dir is empty.
    """
    root = FLAGS.tmp_data_dir
    if not root:
        raise ValueError('Please supply a tmp_data_dir')
    return cifar10_input.distorted_inputs(
        data_dir=os.path.join(root, 'cifar-10-batches-bin'),
        batch_size=FLAGS.batch_size)
def maybe_download_and_extract():
    """Download the CIFAR-10 tarball into FLAGS.tmp_data_dir and unpack it.

    The download is skipped when the archive file already exists. The archive
    is unpacked on every call, so a tarball whose earlier extraction was
    interrupted is still recovered.
    """
    dest_directory = FLAGS.tmp_data_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # Rewrite a single console line with the download percentage.
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # Use a context manager so the archive handle is closed even when
    # extraction fails; the original left the TarFile object open.
    with tarfile.open(filepath, 'r:gz') as archive:
        archive.extractall(dest_directory)
# define slim.layers.repeat
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
def repeat(inputs, repetitions, layer, *args, **kwargs):
    """Apply `layer` to `inputs` `repetitions` times, chaining the outputs.

    Stand-in for slim.layers.repeat, which this script installs itself.

    Args:
      inputs: value converted to a tensor and fed to the first call.
      repetitions: number of times `layer` is applied.
      layer: callable invoked as `layer(outputs, *args, **kwargs)`.
      *args: extra positional arguments forwarded to every call.
      **kwargs: extra keyword arguments forwarded to every call; an optional
        `scope` entry is used as the name prefix for each repetition.

    Returns:
      The output of the final `layer` call.
    """
    base_scope = kwargs.pop('scope', None)
    with variable_scope.variable_op_scope([inputs], base_scope, 'Repeat'):
        net = ops.convert_to_tensor(inputs)
        if base_scope is None:
            # No explicit scope: name the repetitions after the layer callable.
            if hasattr(layer, '__name__'):
                base_scope = layer.__name__
            elif hasattr(getattr(layer, 'func', None), '__name__'):
                # In case layer is a functools.partial.
                base_scope = layer.func.__name__
            else:
                base_scope = 'repeat'
        # Each repetition gets its own scope: <base>_1, <base>_2, ...
        for step in range(1, repetitions + 1):
            kwargs['scope'] = base_scope + '_' + str(step)
            net = layer(net, *args, **kwargs)
        return net
# Install the local repeat() only when the installed TensorFlow does not
# already provide slim.layers.repeat, so a native implementation is never
# overridden.
if not hasattr(slim.layers, 'repeat'):
    slim.layers.repeat = repeat


if __name__ == '__main__':
    tf.app.run()