8
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TensorFlow2.0Advent Calendar 2019

Day 15

TensorFlow2.0の傾向と対策 ~ DeepMind Sonnet2で見る深層学習プログラミング

Last updated at Posted at 2019-12-14

前置き

これはTensorFlow2.0 Advent Calendar 2019の15日目のエントリです。また、「TensorFlow2.0の傾向と対策」シリーズの4つ目の記事でもあります。
この記事では、DeepMindのSonnetのTF2版のSonnet2を紹介し、深層学習ライブラリについての考察を添えます。
他記事に無い部分は以下の項目と思います。

  • Sonnet2
  • tf.Module
  • 提案「訓練シナリオの独立」
  • 深層学習用モジュールの分類

Sonnet2に興味のない方は、最後の方の「深層学習とモジュール化」の章まで飛ばすと良いかもしれません。

Sonnet2について

DeepMindが以前から公開していたSonnetというライブラリがあります。これをTF2化したのがSonnet2です。これは、深層学習のネットワークを記述するためのTensorFlowの薄いラッパーです。
2019年12月15日現在、Sonnetリポジトリのv2ブランチで開発中であり、ベータ版が公開されています。

Sonnet2の使い方

インストール

$ pip install tensorflow-gpu tensorflow-probability
$ pip install "dm-sonnet>=2.0.0b0" --pre

pipコマンドの違いやgpuの有無は環境に応じて行ってください。pipenvを使う場合は注意が必要です。(「XX-manylinux2010_x86_64.whlをpipenvで扱えなかった件とその暫定対処」を参考にしてください)。

基本API

Sonnet2を理解するには、snt.Modulesnt.onceだけ把握すれば良いです。

例えば、こんな感じです。

import tensorflow as tf
import sonnet as snt


class FC(snt.Module):
    def __init__(self, ch, name=None):
        super().__init__(name=name)
        self.ch = ch

    @snt.once
    def _initialize(self, x):
        ch_in = x.shape[-1]
        self.w = tf.Variable(tf.zeros([ch_in, self.ch]), dtype=tf.float32)
        self.b = tf.Variable(tf.zeros([self.ch]), dtype=tf.float32)

    def __call__(self, x):
        self._initialize(x)
        return tf.matmul(x, self.w) + self.b


x = tf.zeros([1, 7])
fc = FC(5)
print(fc(x))
print(fc.variables)
  • snt.Moduleを継承したクラスで__call__メソッドでネットワークを書く
  • __call__の先頭で変数を作る(__initialize
  • 変数の作成は一度でいいので、__initializesnt.onceデコレータをつける
  • 変数はvariablesプロパティで取得する(呼ばれるまで変数が作られないことがあるのに注意)

snt.Module

snt.Moduleはモデルや層を表すクラスです。役割は、変数の管理・ネットワーク(計算)の定義です。
Sonnet2にはsnt.Moduleで作られた層(畳み込み等)があり、それを使うと便利です。snt.Moduleの中にsnt.Moduleを入れれるので、全体等の分かりやすい単位のsnt.Moduleを作成することができます。このとき、内部のsnt.Moduleの変数は、全体のsnt.Moduleの変数一覧(variables)に含まれます。

class TotalModel(snt.Module):
    def __init__(self, name=None):
        super().__init___(self, name=name)
        self.sub1 = SubModel1()
        self.sub2 = SubModel1()
        self.sub3 = SubModel1()

    def __call__(self, x):
        return self.sub3(self.sub2(self.sub1(x)))

snt.once

層を定義するときに入力のチャンネル数を書くのってだるいですよね。ということで、その層に初めて入力が来たときに入力のチャンネル数を見て決定するインターフェースになったライブラリをよく見かけます。Sonnet2でこれを実現する仕組みがsnt.onceです。これは、対象の関数が1回しか実行されないようにするデコレータです。

  • 初期化の処理のメソッドを作る
  • そのメソッドにsnt.onceのデコレータをつける
  • __call__の先頭でそのメソッドを呼ぶ

という流れで使うと分かりやすいプログラムになると思います。

Sonnet2とSonnet1の違い(tf.Module)

TensorFlow内部にたくさんあったニューラルネット用の高レベルAPIをtf.kerasに統一したのがTF2の大きな特徴です。この裏でTensorFlowにtf.Moduleが導入されました(TF1系でも1.14から利用できます)。実を言うと、snt.Moduletf.Moduleの機能をちょっといじった程度のものです。そのため、Sonnetのコア機能がTF2に吸収されたと言えます。
Sonnet1とSonnet2のコア機能のソースコードを比較するとかなりすっきりしたことが分かります。

|バージョン|コア機能|ソースコードの特徴|ソースコード|
|---+---+---+---|
|Sonnet1|snt.AbstractModule|tf.make_templateやcustom_getterを駆使、メソッドが多い|sonnet/python/modules/base.py(コミット: ee1731e87f)|
|Sonnet2|snt.Module|すっきり|sonnet/src/base.py(コミット: 84817e56a2)|

機械学習・深層学習で得られるモデルは、言ってみれば、(大量の)パラメータで構成された関数です。この単位(パラメータ+関数)でニューラルネットを管理する機能を持ったライブラリが多く、tf.ModuleでTensorFlowにもそれが導入されたというわけです。
実は、tf.keras.Modeltf.Moduleを継承しています。

SonnetとKerasの違い(訓練シナリオの有無)

ここまで読んでくれた読者は「もうTensorFlowの高レベルAPIはKerasって決まったんだから、それ以外のライブラリを勧めないでくれ!!」みたいに怒り出したりしないでしょうが、Kerasとの違いは気になると思います。両者の違いはカバー範囲にあります。Kerasがネットワーク記述から訓練ループまでサポートしているのに対して、Sonnetが扱うのは個々の部品です。
典型的な訓練パターンであれば、ネットワークをtf.keras.Modelの形で用意し、compileメソッドでoptimizerをセットしてfitfit_generatorを用いて訓練するというのがKerasのやり方です。Kerasの優れた点は、callbackを使って訓練ループをカスタマイズできることです。
一方で、Sonnetの背景には多種の訓練パターンに取り組んでいるDeepMindの研究者の事情があります。READMEには以下のように書かれています。

Sonnet has been designed and built by researchers at DeepMind. It can be used to construct neural networks for many different purposes (un/supervised learning, reinforcement learning, ...). We find it is a successful abstraction for our organization, you might too!

違いをまとめるとこんな感じでしょうか。

|ライブラリ|サポート範囲|想定ユーザー|訓練シナリオ|
|---+---+---+---|
|Keras|ネットワークから訓練まで多機能|典型的な訓練を行う幅広い人|fit等、callbackでカスタマイズ|
|Sonnet|ネットワーク|DeepMindの研修者のような多様な訓練を行う人|別途用意|

深層学習とモジュール化

相変わらず深層学習の研究は盛んで、論文の再現用コード・訓練ツール・ライブラリ等の形で次々公開されています。近年の手法では、特に、

  • アーキテクチャの複雑化
  • マルチGPUの活用法の多様化
  • データセットの個性化・多様化

が進んでいます。そのような状況のせいか、論文の再現用コードから新しい手法を部分的に取り込もうにも、容易でない場合が多く見受けられます。人気の分野であれば、訓練ツールやライブラリとして整備されますが、そうでないと強い人が力技で書いたコードと魔改造を重ねた秘伝のたれが蔓延しかねません。それを防ぐ上手いモジュール化を議論し共有したいです(この章の内容は「TensorFlow2.0の傾向と対策 〜 TensorFlowの消失」で勧めたリファクタリングの具体案にあたり、当時書こうとしてさぼっていたものです)。

提案は、(訓練)シナリオのモジュールを独立させるという案です。

訓練シナリオの例

例えば、次の疑似コードようにモデル等を引数で動作させる仕組みで抽象化する感じです。

class GANTrainSenario:
    def __init__(self, step, save_interval, save_dir):
        """
        訓練ループに関わる設定をまず行う
        """
        self.step = step
        self.save_interval = save_interval
        self.save_dir = save_dir

    def __call__(self, model_g: tf.Module, model_d: tf.Module, opt_g, opt_d, loss_g, loss_d, data_loader):
        ```
        ループと独立したものはこの引数で受け取る
        ```
        for i in range(1, self.step + 1):
            input_g = data_loader.get("g")  # data_loaderは雰囲気で解釈してください
            opt_g.minimize(lambda: loss_g(model_d(model_g(input_g))), variables=model_g.variables)

            input_g, input_d, expected = data_loader.get("gd")
            input_d_concat = concat(model_g(input_g), input_d)
            opt_g.minimize(lambda: loss_d(model_d(input_d_concat), expected), variables=model_d.variables)

            if i % self.save_interval == 0:
                save(model_g, model_d, save_dir)

例えば、次のようなディレクトリ構造でファイルを分けるだけでも、部分的に利用したい人にとって大いに助けになると思います。

.
├── oreore_net
│   ├── __init__.py
│   ├── dataset.py
│   ├── loss.py
│   ├── net.py
│   └── scenario.py
├── README.md
├── samples  # ここに論文の再現用コードを置くなど
└── setup.py  # `pip install`できるようにする

SonnetやKerasを見ながら考える深層学習のモジュール

深層学習はそれなりに多くの要素で構成されます。それらは凡そ次のように分類できると考えています。

  • 計算系(ニューラルネット, optimizer, loss)
  • リソース系(計算リソース, データセット, 作業領域, pretrainな初期値)
  • シナリオ系(訓練, 推論)
  • 評価系
  • エクスポート系(TensorFlowLite, TensorRT, ...)

Sonnetは計算系で、Kerasは全部入りといった感じです。

最後に

Sonnet2はまだベータ版であり、正式リリースが待ち遠しいです。
深層学習のモジュールの分類をしましたが、その中でも、シナリオ系・評価系は他に比べて目立たず充実していないように思います(主観的ですが)。シナリオ系・評価系のライブラリが充実すると世界中の研究がはかどるので、そういうのが増えて欲しいです。

8
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?