はじめに
あけましておめでとうございます。今日から2021年です。新年ということでですね、haikuをよんでいこうと思います。
haiku(dm-haiku)とは
haiku(dm-haiku)はdeepmind製のjax向け深層学習モデル記述ライブラリです。まだ開発途中でβ版として公開されています。
sonnetをご存じの方には、haikuはsonnetのjax版と言うと分かりやすいと思います。sonnet → sonnet2 → haiku と進化する過程を見ると面白いです。
先月に、DeepMindのブログで、DeepMindがjaxを使っているよ、色々ライブラリ作って公開しているよ、という旨の内容が紹介されていました。また、ODE-GAN(論文、コード)という研究にhaikuが使われています(この論文もDeepMindの方々が執筆)。ODE-GANは、ルンゲクッタ法や、パラメータの勾配に対する正則化のように、柔軟な勾配計算を要求する手法で、tensorflowで書くと大変そうな手法なのですが、ODE-GANのコードはhaikuを使ってシンプルに書かれています。これが、haikuについて調べてみようと思った経緯です。
haikuを使った例
まず、haikuの機能の紹介に、以下の例を用意しました。コメントの[A]から[G]の部分がhaikuの特徴的な部分です。コードの後でこれらを解説します。
import haiku as hk # [A] haikuのimport
import jax
import jax.numpy as jnp
# 典型的なモジュールの書き方
class MyModule(hk.Module):
def __init__(self, n_hidden_layer, dim_hidden, dim_out, name=None):
super().__init__(name=name) # [B] name scopeを作成
self._hidden_layers = [hk.Linear(dim_hidden) for _ in range(n_hidden_layer)]
self._dim_hidden = dim_hidden
self._last_layer = hk.Linear(dim_out)
def __call__(self, x): # [C] 計算内容の記述
for i, layer in enumerate(self._hidden_layers):
# [D] 変数の取得
b = hk.get_parameter(f"b_{i}", shape=[self._dim_hidden], dtype=x.dtype, init=jnp.zeros)
x = jax.nn.relu(layer(x)) + b
return self._last_layer(x)
# transformを使ったコンパクトな書き方
def get_my_model_by_function(n_hidden_layer, dim_hidden, dim_out):
def my_module_by_function(x):
for i in range(n_hidden_layer):
b = hk.get_parameter(f"b_{i}", shape=[dim_hidden], dtype=x.dtype, init=jnp.zeros)
x = jax.nn.relu(hk.Linear(dim_hidden)(x)) + b
return hk.Linear(dim_out)(x)
return hk.transform(my_module_by_function) # [E] 関数をモデルに変換する
# transformを使ったモデルの使い方
model = get_my_model_by_function(5, 100, 10)
key = hk.PRNGSequence(20210101) # [F] 乱数生成のキーの作成が簡単にできる便利イテレータ
x = jnp.ones([1, 30])
params = model.init(next(key), x) # [G] transformで作ったモデルにはinitとapplyの関数がある
y = model.apply(params, None, x)
例の説明(import)
[A]のようにhaikuをhkと省略してimportするのがおすすめのようです。haikuを使う場合、同時にjaxやjax.numpyやnumpyも一緒にimportすることが多いでしょう。jax.numpyはjnpと省略される例が多いです。
例の説明(hk.Module)
モジュール(全結合層や畳み込み層などの層)の記述にhk.Module
を使います。**hk.Module
がhaikuの最も基礎的な機能です。**モジュールの役割は、以下の2点です。
- そのモジュール用の名前空間を作成する(内部で作った変数等にそのモジュールのものと分かる名前が付く)
- モジュールの計算内容を記述する
hk.Module
は、sonnetのAbstractModule
やsonnet2のModule
に相当するものです。インターフェースはsonnet2のModule
と同じです。2系のtensorflowのtf.Module
も似た機能です。
1番はコードの[B]の部分です(親クラスでいい感じにやってくれます)。tensorflowでいうところのvariable_scope
やname_scope
にあたる機能です。sonnet(v1)の時代から似たものがあり、sonnet(v1)ではvariable_scope
とcustom_getter
で実現されていました。tensorflow2系からtf.Module
ができたため、sonnet2のModule
はtf.Module
をベースに作成されています。jaxにはname_scope
のようなものがないため、haikuはそれに相当するものを用意しています。
2番はコードの[C]の部分です。pythonでは__call__
というメソッドがあると、関数のようにmodel(x)
のように使えます(__call__
メソッドの内容を実行します)。ここもsonnet2の典型的な使い方と同じです。ただし、sonnet2では、変数の作成にonce
という関数でデコレートしたメソッドで、最初の一回だけパラメータ作成のために呼ぶメソッドを用意するパターンを採用していました。tensorflowが2系になったときにget_variable
等の関数が無くなり、その影響でこのような面倒が発生していました。haikuでは、get_parameter
という関数があり、最初に呼ばれる場合を意識することなく変数を扱うことができます。
例の説明(hk.transform)
モジュールを組み合わせてモデルを作るわけですが(そのモデルを一つのモジュールで書いてもいいですが)、扱いやすいやすさ・インターフェースの整理のために、transform
やtransform_with_state
という関数があります(コードの[E]の部分)。簡単のためtransform
だけに話を絞ります。transform
は(モジュールを使う)関数をモデルに変換するものです。返り値の型はTransformed
という型ですが、ただのnamedtupleです。Transformed
はinit
とapply
という属性を持ちます。**init
はモデルの初期化を行う関数で、モデルの中のパラメータの値を返します。逆にapply
はパラメータと入力を引数に与えて、そのモデルの計算結果を返す関数です。**コードの[G]の部分が該当します。init
の引数は、乱数生成のキー・入力の順で、apply
の引数は、モデルのパラメータ・乱数生成のキー・入力の順です。**パラメータを引数にするあたりが特徴的です。**冒頭のODE-GANのような手法の運用に便利なインターフェースです。
例の説明(その他)
jaxを触ってみると分かることですが、jaxの乱数の仕組みは他(numpyやtensorflow)と比べて複雑です。乱数生成のキーを生成する仕組みを使って・・・というような処理が必要です。高度なコントロールができるのが利点ですが、利用するには簡単なものがあると便利です。haikuのPRNGSequence
はそれを簡単に使えるものです(コードの[F]の部分)。PRNGSequence
を使うと乱数生成のキーを簡単に用意できるため、苦労が減ることでしょう。
haikuのよみ方
haikuの基本構造
haikuだけでなく、前身のsonnet・sonnet2でもそうですが、これらはシンプルな構造のライブラリになることを目指しています。その構造は少量のコア機能とそれを使った具体的な部品からなります。haikuの場合具体的には以下のものから構成されます。
役割 | 内容 | 補足 |
---|---|---|
コア機能 | hk.Module |
|
部品 |
hk.Linear など |
hk.Module を継承して多くの層が用意されている |
便利機能 |
hk.transform やhk.PRNGSequence など |
hk.transform はコア機能かも |
sonnetも似た構造をしています(「どうせ同じでしょ」と思ってhaikuをよんだらやっぱりそうでした)。
どこをよむといいか
どこをよむといいかなんて、よむ人次第ですが、「haikuの仕組みが知りたい」「haikuのどの仕組みが他ライブラリとの違いか」といった部分に興味があったので、コア機能を中心によみました。逆にhaikuを使って自作の層を作ってみたい人は、部品を中心によむとよいでしょう。
haikuをよむ
前章で説明した通り、コア機能に興味があるので、そこを中心によんで分かったことを紹介していきます。
また、そのときのバージョンのソースコードのURLは、https://github.com/deepmind/dm-haiku/tree/300e6a40be31e35940f0725ae7ed3457b737a5a3です。
ディレクトリ構造
haikuの主要なファイル・ディレクトリの構造は以下の通りです。基本的にプログラムはhaiku/_src/XX.py
に書き、haiku
ディレクトリ直下のpythonファイルで公開APIの部分だけをインポートしています。
.
├── WORKSPACE
├── haiku
│ ├── BUILD
│ ├── __init__.py
│ ├── _src
│ │ ├── base.py
│ │ ├── data_structure.py
│ │ ├── stateful.py
│ │ ├── module.py
│ │ ├── transform.py
│ │ └── typing.py
│ ├── data_structures.py
│ ├── experimental.py
│ ├── initializers.py
│ ├── nets.py
│ ├── pad.py
│ ├── testing.py
│ └── typing.py
├── requirements-jax.txt
├── requirements-test.txt
├── requirements.txt
├── setup.py
└── test.sh
haiku/_src
ディレクトリの中身はコア機能に関係するものだけを書きました。
haiku/_src/base.py
haiku/_src/base.py
には、name_scopeを支える仕組みやPRNGSequenceが書かれています。
name_scopeには、このスコープ中のこのスコープの・・・というようなスコープの階層構造があります。name_scopeでは今、どのスコープにいるかを把握しないといけません。そのために「今どこのスコープにいるか」を意味するグローバル変数を用意しています。具体的には、stackとして実現しており、スコープに入ったらpush、スコープから出たらpopする。スコープを意識した処理をするときは、例えばstackの最後を見て処理をするであったり、stackの先頭から順に何かを適用したりします(再帰で似たことを実現しているコードがなん箇所かにあります)。雑にname_scopeと書きましたが、このような階層構造を、変数・状態変数(BatchNormalizationの移動平均等)・乱数・モジュール・名前それぞれのためにstackを用意していました。
haiku/_src/data_structure.py
haiku/_src/data_structure.py
にはhaikuのコア機能内のために使われる基礎的なデータ構造が書かれています。主要なものはStackとFlatMappingです。
Stackはbase.pyで説明したように、階層構造のどこにいるのかを表現するために使われます。
FlatMappingはモデルの変数一覧等に用いられます。haikuのモデルの変数は階層構造があります。その階層構造を表現するためにdictのかわりにこのFlatMappingを使います。Transformed
のinit
の返り値などで用いられます。FlatMappingはjax.tree_XXというような名前のjaxの関数と一緒に使うために準備されたものです。jax.tree_XXという関数を使って、パラメータの更新や、勾配同士の演算をしたり(ルンゲクッタ法のような)します。
haiku/_src/stateful.py
haiku/_src/stateful.py
はtransofrom
等の中でjaxの関数とhaikuのスコープを整合的に運用するための機能が実装されています。jaxの基礎的な関数をhaiku用にラップしたものがあります。細かなパーツがたくさんある感じなのと、そこまでjaxのあれこれを知っているわけではないので、現状は深く踏み込めなくてこのファイルはこれくらいで諦めました。
haiku/_src/module.py
haiku/_src/module.py
は名前からわかるとおり、Module
が書かれています。また、Module
の基本機能のname_scope自体もこのファイルにあります。
haiku/_src/transform.py
haiku/_src/transform.py
も名前からわかるとおり、transform
が実装されています。実装を見ていくと、少したらいまわしにされますが、init
とapply
の二つの関数を適切なスコープで作ってTransformed
として返すだけの関数です。
haiku/_src/typing.py
haiku/_src/typing.py
はよく使う型を特別に変数に代入しているだけです。一箇所気になるのは、次のような行です。
Module = typing._ForwardRef("haiku.Module")
Module
が最も基礎的な機能なので、typing.pyで型ヒント用の変数を用意したいのはよくわかります。forward referenceという実験的な機能があるようで、それが使えたら使うということをしています。
最後に
ということでね、haikuをよみました。
2020年はtransformerの高速化や画像認識への適用、GPT3、alphafoldの活躍といった話題がありましたが、2021年はどんな技術・手法が登場するのでしょうか。今はtensorflow/pytorchの2強状態ですが、jaxやhaikuじゃないと上手く書けないような高度な手法とかが登場したりするんでしょうかね?東京オリンピックもどうなるんでしょうね?開会式などで日本の技術力をアピールするような深層学習を使った何かがみれたりするんでしょうかね?
ポエム
新年や 離れて祝い haikuよむ