3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

深層学習ライブラリSonnetを使ってMNISTを識別

Last updated at Posted at 2017-04-21

Sonnetとは

Sonnetは、DeepMind社が公開したTensorFlowベースの深層学習ライブラリです。
2017年4月7日に公開されました。
Sonnetは現状Python2.7系のみサポートしています。

筆者の開発環境

わたしの環境は以下の通りです。
環境が異なる方は適宜読み替えて対応してください。

  • macOS: Sierra 10.12.4
  • Python: 2.7.9 (pyenv)
  • Sonnet: 1.0
  • TensorFlow: 1.0.1

インストール

はじめにBazelというビルドツールを導入します。

Installing Bazel

次に、TensorFlowをインストールします。
ここでは、virtualenvで作った仮想環境上で環境を構築する方法を書きます。

$ mkdir sonnet
$ cd sonnet
$ pip install virtualenv
$ virtualenv venv
$ source ./venv/bin/activate
(venv)$ pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.0.1-py2-none-any.whl

Sonnetのソースコードをcloneします。

(venv) $ git clone --recursive https://github.com/deepmind/sonnet

はじめにGPU周りの設定などを行うためのconfigureというスクリプトを実行します。

(venv) $ cd sonnet/tensorflow
(venv) $ ./configure
(venv) $ cd ../

Sonnetをビルドします。

(venv) $ mkdir /tmp/sonnet
(venv) $ bazel build —config=opt :install
(venv) $ ./bazel-bin/install /tmp/sonnet

最後に作られたwheelをpip installします。

(venv) $ pip install /tmp/sonnet/*.whl

Sonnetのインストールに関する日本語記事がすでにあるので、
うまく行かなかった人はこちらを参考にしても良いかもしれません。
Mac OSにSonnetをインストール

MNIST Classification

MNISTとは、0から9の数字の手書き文字データセットで、
このデータセットの識別は機械学習における最も有名な入門課題の1つです。
本記事では、MNISTの分類をSonnetを用いて実装する方法を説明します。
ソースコードはGithub上で公開しています。
kiyomaro927/sonnet_mnist

ライブラリのインポート

依存ライブラリをインポートします。
インストールに失敗していれば、この時点でこけるはずです。

import sonnet as snt
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

ハイパーパラメータの設定

学習率などのパラメータを設定します。
今回はTensorFlowのFLAGSを利用して定義します。

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer("num_training_iterations", 1000,
                        "Number of iterations to train for.")
tf.flags.DEFINE_integer("report_interval", 50,
                        "Iterations between reports (samples, valid loss).")
tf.flags.DEFINE_integer("batch_size", 64, "Batch size for training.")
tf.flags.DEFINE_integer("num_hidden", 128, "Size of MLP hidden layer.")
tf.flags.DEFINE_integer("output_size", 10, "Size of MLP output layer.")

分類モデルの定義

Sonnetで分類モデルを定義します。
Sonnetでは、snt.AbstractModuleを継承したクラスを作ることで、独自のモジュールを定義します。

class MLP(snt.AbstractModule):
    """MLP model, for use on MNIST dataset."""

    def __init__(self, num_hidden, output_size,
                 nonlinearity=tf.sigmoid, name='mlp'):
        """Construct a `MLP`.

        Args:
            num_hidden: Number of hidden units in first FC layer.
            output_size: Size of the output layer on top of the MLP.
            nonlinearity: Activation function.
            name: Name of the module.
        """

        super(MLP, self).__init__(name=name)

        self._num_hidden = num_hidden
        self._output_size = output_size
        self._nonlinearity = nonlinearity

        with self._enter_variable_scope():
            self._l1 = snt.Linear(output_size=self._num_hidden, name='l1')
            self._l2 = snt.Linear(output_size=self._output_size, name='l2')

    def _build(self, inputs):
        """Builds the MLP model sub-graph.

        Args
            inputs: A Tensor with the input MNIST data encoded as a
            784-dimensional representation. Its dimensions should be
            `[batch_size, 784]`.

        Returns:
            A Tensor with the prediction of given MNIST data encoded as a
            10-dimensional representation. Its dimensions should be
            `[batch_size, 10]`.
        """

        l1 = self._l1
        h = self._nonlinearity(l1(inputs))

        l2 = self._l2
        outputs = tf.nn.softmax(l2(h))

        return outputs

重要な部分をかいつまんで説明します。
はじめに__init__()内で、スーパークラスのコンストラクタにモジュール名を渡します。

super(MLP, self).__init__(name=name)

__init__内では、クラスのメンバ変数を定義しますが、
中にはニューラルネットワークの構成も含めて、
必要な変数は全てコンストラクタの中で定義したい方もいるかと思います。

最適化の対象になるパラメータを持つ変数をコンストラクタの中で定義する時は、
self._enter_variable_scope()にネストして定義します。

with self._enter_variable_scope():
    self._l1 = snt.Linear(output_size=self._num_hidden, name='l1')
    self._l2 = snt.Linear(output_size=self._output_size, name='l2')

次に_build()を定義します。
この関数は、このモジュールが計算グラフに接続されるたびに呼ばれます。
この関数内で定義されるVariable(ネットワークの重みやバイアスなど)は共有変数として機能します。
TensorFlowにおける共有変数の考え方は以下の記事によくまとまっています。
TensorFlow の名前空間を理解して共有変数を使いこなす

def _build(self, inputs):
    """Builds the MLP model sub-graph.

    Args
        inputs: A Tensor with the input MNIST data encoded as a
        784-dimensional representation. Its dimensions should be
        `[batch_size, 784]`.

    Returns:
        A Tensor with the prediction of given MNIST data encoded as a
        10-dimensional representation. Its dimensions should be
        `[batch_size, 10]`.
    """

    l1 = self._l1
    h = self._nonlinearity(l1(inputs))

    l2 = self._l2
    outputs = tf.nn.softmax(l2(h))

    return outputs

データセットモデルの定義

TensorFlowでは入力や出力に仮の変数としてtf.placeholderを定義しますが、
Sonnetのモジュールの計算グラフに接続されるたびに_buildが呼ばれるという仕組みを利用してデータセットもSonnetのモジュールとして定義すると、すっきりしたコードが書けます。
公式のチュートリアルに習ってcost()を定義しています。

class MNIST(snt.AbstractModule):
    """MNIST dataset model."""

    def __init__(self, mnist, batch_size, name='mnist'):
        """Construct a `MNIST`.

        Args:
            mnist: Dataset class object which has MNIST data.
            batch_size: Size of the output layer on top of the MLP.
            nonlinearity: Activation function.
            name: Name of the module.
        """

        super(MNIST, self).__init__(name=name)

        self._num_examples = mnist.num_examples
        self._images = tf.constant(mnist.images, dtype=tf.float32)
        self._labels = tf.constant(mnist.labels, dtype=tf.float32)
        self._batch_size = batch_size

    def _build(self):
        """Returns MNIST images and corresponding labels."""
        indices = tf.random_uniform([self._batch_size],
                                    0, self._num_examples, tf.int64)
        x = tf.gather(self._images, indices)
        y_ = tf.gather(self._labels, indices)
        return x, y_

    def cost(self, logits, target):
        """Returns cost.

        Args:
            logits: Model output.
            target: Correct labels.

        Returns:
            Cross-entropy loss for given outputs.
        """

        return -tf.reduce_sum(target * tf.log(logits))

学習フローの定義

データセットと識別モデルをSonnetのモジュールとして定義したので、それらのインスタンスを作成し、学習のフローを定義します。

mnist = read_data_sets('./MNIST_data', one_hot=True)
dataset_train = MNIST(mnist.train, batch_size=FLAGS.batch_size)
dataset_validation = MNIST(mnist.validation, batch_size=FLAGS.batch_size)
dataset_test = MNIST(mnist.test, batch_size=FLAGS.batch_size)

model = MLP(num_hidden=FLAGS.num_hidden, output_size=FLAGS.output_size)

# Build the training model and get the training loss.
train_x, train_y_ = dataset_train()
train_y = model(train_x)
train_loss = dataset_train.cost(train_y, train_y_)

# Get the validation loss.
validation_x, validation_y_ = dataset_validation()
validation_y = model(validation_x)
validation_loss = dataset_validation.cost(validation_y, validation_y_)

# Get the test loss.
test_x, test_y_ = dataset_test()
test_y = model(test_x)
test_loss = dataset_test.cost(test_y, test_y_)

# Set up optimizer.
train_step = tf.train.AdamOptimizer().minimize(train_loss)

学習ループの実行

最後に学習ループを回します。
TensorFlowでは定義した計算グラフをSessionオブジェクトに渡して計算を行います。

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for training_iteration in range(FLAGS.num_training_iterations):
        if (training_iteration + 1) % report_interval == 0:
            train_loss_v, validation_loss_v, _ = sess.run(
                (train_loss, validation_loss, train_step))

            tf.logging.info("%d: Training loss %f. Validation loss %f.",
                            training_iteration,
                            train_loss_v,
                            validation_loss_v)
        else:
            train_loss_v, _ = sess.run((train_loss, train_step))
            tf.logging.info("%d: Training loss %f.",
                            training_iteration,
                            train_loss_v)

    test_loss = sess.run(test_loss)
    tf.logging.info("Test loss %f", test_loss)

おわりに

SonnetでMNISTを分類するサンプルについて説明しました。
MNISTを実装した程度で使いやすさについてとやかく言うつもりはありませんが、
少なくとも現状、ChainerやKerasなどの抽象度の高い深層学習フレームワークを使っていて特に窮屈さを感じていないなら乗り換えを検討する必要はないかと思います。

一方で、TensorFlowで全てのコードを書いている(かつPython2.7系を使っている)のであれば、Sonnetの導入はかなり容易であり可読性の高いコードを書く助けになるかと思います。

参考になれば幸いです。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?