Help us understand the problem. What is going on with this article?

DeepMindのSonnetを触ったので、TensorFlowやKerasと比較しながら解説してみた

More than 3 years have passed since last update.

はじめに

GoogleのDeepMindが社内的に使っている深層学習ライブラリSonnetを公開したので、試しに触ってみました。TensorFlowやKerasと比較をしながら解説していきたいと思います。まだ触ってみたばかりで細かい内容をお話することは出来ませんが、少しでもSonnetが気になっている方の力になれれば嬉しいです。
間違ってる部分や質問などありましたらコメントして頂けると幸いです。

Sonnet

SonnetはGoogleのDeepMindが社内的に使っている深層学習ライブラリです。TensorFlowのラッパーライブラリになっており、TensorFlowで書くのがちょっと面倒な部分にSonnetを使うことで今までよりも楽に速くコードを書くことが出来ます。現段階ではPython2.7にしか対応していません。

KerasもTensorFlowのラッパーライブラリですが、Kerasを用いる際は基本的にKerasの関数のみでコードを書いていました。SonnetはTensorFlowとミックスしてコードを書く点などで違ってきます。Sonnetは以下のGitHubのページにて公開されています。今後はこちらのページを参考に進めていきます。

※今回は分かりやすくするために、出来るだけコードの重要な部分のみ記述しています。

ネットワークの定義

Sonnetは、基本的にTensorFlowの関数と組み合わせてネットワークを定義します。TensorFlowでは重みやバイアスの定義や行列計算の部分などを書く必要がありましたが、Sonnetではその必要は無くなります。今までTensorFlowを書いていた方にとってはとても簡潔なコードに見えると思います。ここでは、よく用いられる層である、全結合層、畳み込み層などを用いて定義したネットワークを比較していきます。

ネットワークの構成

  • 畳み込み層
  • 正規化(BatchNormalization)
  • 活性化(Relu)
  • 平坦化(Flatten)
  • 全結合層

Sonnetでのネットワークの定義は以下のようになります。

import sonnet as snt

conv = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)
bn = snt.BatchNorm()
bf = snt.BatchFlatten()
linear = snt.Linear(output_size=10)
model = snt.Sequential([conv, bn, tf.nn.relu, bf, linear])

または

import sonnet as snt

def build(inputs):
    outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(inputs)
    outputs = snt.BatchNorm()(outputs)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    outputs = snt.Linear(output_size=10)(outputs)
    return outputs
model = snt.Module(build)

上の書き方は、順番は関係なく各層を定義し、snt.Sequential()で活性化関数を含めて1つのネットワークとして定義しています。下の書き方はTensorFlowに似た書き方だと思います。各層を前の層の出力を入力とした関数として定義しています。

TensorFlowに慣れている方は、下の書き方の方が扱いやすいのではないかなと思います。最終的にはどちらの書き方をしても同じものとして扱うことができるので、お好きな書き方でお書きください。

SonnetはTensorFlowと互換性があり、ネットワークの定義の際もTensorFlowの関数とミックスして定義することが出来ます。例えば、下の書き方で最後の全結合層outputs = snt.Linear(output_size=10)(outputs)をTensorFlowの関数として定義した場合、以下のようになります。

import sonnet as snt

def build(inputs):
    outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(inputs)
    outputs = snt.BatchNorm()(outputs)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    # TensorFlow
    weight = tf.Variable(tf.truncated_normal(shape=[4096,10]))
    bias = tf.Variable(tf.constant(0.1, shape=[1,10]))
    outputs = tf.matmul(outputs, weight) + bias
    outputs = tf.nn.relu(outputs)
    return outputs
model = snt.Module(build)

このようにSonnetとTensorFlowは自由に組み合わせて使うことが出来ます。基本的にはSonnetで書いて、細かく設定したい部分のみTensorFlowで書くということが出来るのでとても便利ですね。Sonnetはこのように抽象度を自由に変えれるので、「TensorFlowで書くのはちょっと大変だけど、Kerasは細かくいじりにくいなぁ」と悩んでいる方にはとても良いライブラリなのではないかなと思います。また、必要なコードの量が減るので、TensorFlowを使っていた方は今までよりも楽に高速にコードを書くことが出来るようになると思います。

Sonnetの誤差関数や最適化手法などの設定は、TensorFlowと同じように記述します。

outputs = model(inputs)
loss = tf.nn.l2_loss(targets - outputs)
train_step = tf.train.AdamOptimizer().minimize(loss)

また、Kerasでは以下のように記述します。

model = Sequential()

model.add(Convolution2D(nb_filters, nb_conv, nb_conv, 
                        border_mode='valid',
                        input_shape=(1,128,128)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(10))

model.compile(loss='mean_squared_error',
              optimizer=Adam())

Kerasもすっきりとしたわかりやすいコードで記述することが出来ます。しかし、Sonnetに比べるとTensorFlowと組み合わせる際の柔軟性が低く、抽象度が高くなってしまうため、細かい設定をしにくいというデメリットがあります(Kerasユーザの方すみません)。これがSonnetとKerasの一番大きな違いだと思います。ただ、KerasはTensorFlowやSonnetよりも簡単に記述できるので、細かいことをしたりしない場合やディープラーニングをちょっと触ってみたいなって方にはKerasをおすすめします。

学習フェーズ

Sonnetの学習フェーズはTensorFlowと同じように記述します。

with tf.Session() as sess:
    for epoch in range(nb_epoch):
        for i in range(len(input_size)/batch_size):
            sess.run(train_step, feed_dict=feed_dict)

Kerasでは以下のように記述します。

model.fit(X_train, Y_train,
          batch_size=batch_size,
          nb_epoch=nb_epoch)

SonnetはTensorflowと同じく、各エポックの各ステップでバッチサイズと同じ大きさの入力データと教師データを関数に入力して学習します。また、ネットワークに入力し出力した値を得るフェーズと、出力した値と教師データから誤差を求め学習するフェーズに分けることが出来ます。Kerasでは基本的に一度にすべての入力データと教師データを同時に入力して学習します。教師あり学習の際にはKerasはとても便利なのですが、強化学習など、出力してからじゃないと教師データを作成できない場合などはSonnetなどの方が書きやすいのかなと思います。

まとめ

以上のように、Sonnetを用いることで抽象度を自由に調節できるようになり、TensorFlowで書いていたコードをTensorFlowベースのまま高速に記述することが可能になります。Sonnetは、TensorFlowの汎用性とKerasの書きやすさの両方を兼ね備えたとても良いライブラリなのではないかなと思います。唯一のデメリットはPython2にしか対応していないという部分だと思います。今後のPython3対応に期待ですね!

ここまで読んでいただきありがとうございました!コードの詳しい説明をしていなかったり、分かりにくい部分もあるとは思いますが、少しでもSonnetに興味を持ってくれた方がいればとても嬉しいです!私もまだ全然使いこなせていないので、これからもっと触って行きたいなと思います。記事ではKerasのことを少し悪く書いたりしてしまいましたが、決してKerasが嫌いなわけではありません。用途や好みに合ったライブラリを選ぶ際の参考に少しでもなればとても嬉しいです。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした