11
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

dm-haikuをよむ

Last updated at Posted at 2021-01-01

はじめに

あけましておめでとうございます。今日から2021年です。新年ということでですね、haikuをよんでいこうと思います。

haiku(dm-haiku)とは

haiku(dm-haiku)はdeepmind製のjax向け深層学習モデル記述ライブラリです。まだ開発途中でβ版として公開されています。
sonnetをご存じの方には、haikuはsonnetのjax版と言うと分かりやすいと思います。sonnet → sonnet2 → haiku と進化する過程を見ると面白いです。
先月に、DeepMindのブログで、DeepMindがjaxを使っているよ、色々ライブラリ作って公開しているよ、という旨の内容が紹介されていました。また、ODE-GAN(論文コード)という研究にhaikuが使われています(この論文もDeepMindの方々が執筆)。ODE-GANは、ルンゲクッタ法や、パラメータの勾配に対する正則化のように、柔軟な勾配計算を要求する手法で、tensorflowで書くと大変そうな手法なのですが、ODE-GANのコードはhaikuを使ってシンプルに書かれています。これが、haikuについて調べてみようと思った経緯です。

haikuを使った例

まず、haikuの機能の紹介に、以下の例を用意しました。コメントの[A]から[G]の部分がhaikuの特徴的な部分です。コードの後でこれらを解説します。

import haiku as hk  # [A] haikuのimport
import jax
import jax.numpy as jnp


# 典型的なモジュールの書き方
class MyModule(hk.Module):
  def __init__(self, n_hidden_layer, dim_hidden, dim_out, name=None):
    super().__init__(name=name)  # [B] name scopeを作成
    self._hidden_layers = [hk.Linear(dim_hidden) for _ in range(n_hidden_layer)]
    self._dim_hidden = dim_hidden
    self._last_layer = hk.Linear(dim_out)

  def __call__(self, x):  # [C] 計算内容の記述
    for i, layer in enumerate(self._hidden_layers):
      # [D] 変数の取得
      b = hk.get_parameter(f"b_{i}", shape=[self._dim_hidden], dtype=x.dtype, init=jnp.zeros)
      x = jax.nn.relu(layer(x)) + b
 
    return self._last_layer(x)


# transformを使ったコンパクトな書き方
def get_my_model_by_function(n_hidden_layer, dim_hidden, dim_out):
  def my_module_by_function(x):
    for i in range(n_hidden_layer):
      b = hk.get_parameter(f"b_{i}", shape=[dim_hidden], dtype=x.dtype, init=jnp.zeros)
      x = jax.nn.relu(hk.Linear(dim_hidden)(x)) + b
    return hk.Linear(dim_out)(x)
  return hk.transform(my_module_by_function)  # [E] 関数をモデルに変換する


# transformを使ったモデルの使い方
model = get_my_model_by_function(5, 100, 10)
key = hk.PRNGSequence(20210101)  # [F] 乱数生成のキーの作成が簡単にできる便利イテレータ
x = jnp.ones([1, 30])
params = model.init(next(key), x)  # [G] transformで作ったモデルにはinitとapplyの関数がある
y = model.apply(params, None, x)

例の説明(import)

[A]のようにhaikuをhkと省略してimportするのがおすすめのようです。haikuを使う場合、同時にjaxやjax.numpyやnumpyも一緒にimportすることが多いでしょう。jax.numpyはjnpと省略される例が多いです。

例の説明(hk.Module)

モジュール(全結合層や畳み込み層などの層)の記述にhk.Moduleを使います。**hk.Moduleがhaikuの最も基礎的な機能です。**モジュールの役割は、以下の2点です。

  1. そのモジュール用の名前空間を作成する(内部で作った変数等にそのモジュールのものと分かる名前が付く)
  2. モジュールの計算内容を記述する

hk.Moduleは、sonnetのAbstractModuleやsonnet2のModuleに相当するものです。インターフェースはsonnet2のModuleと同じです。2系のtensorflowのtf.Moduleも似た機能です。
1番はコードの[B]の部分です(親クラスでいい感じにやってくれます)。tensorflowでいうところのvariable_scopename_scopeにあたる機能です。sonnet(v1)の時代から似たものがあり、sonnet(v1)ではvariable_scopecustom_getterで実現されていました。tensorflow2系からtf.Moduleができたため、sonnet2のModuletf.Moduleをベースに作成されています。jaxにはname_scopeのようなものがないため、haikuはそれに相当するものを用意しています。
2番はコードの[C]の部分です。pythonでは__call__というメソッドがあると、関数のようにmodel(x)のように使えます(__call__メソッドの内容を実行します)。ここもsonnet2の典型的な使い方と同じです。ただし、sonnet2では、変数の作成にonceという関数でデコレートしたメソッドで、最初の一回だけパラメータ作成のために呼ぶメソッドを用意するパターンを採用していました。tensorflowが2系になったときにget_variable等の関数が無くなり、その影響でこのような面倒が発生していました。haikuでは、get_parameterという関数があり、最初に呼ばれる場合を意識することなく変数を扱うことができます。

例の説明(hk.transform)

モジュールを組み合わせてモデルを作るわけですが(そのモデルを一つのモジュールで書いてもいいですが)、扱いやすいやすさ・インターフェースの整理のために、transformtransform_with_stateという関数があります(コードの[E]の部分)。簡単のためtransformだけに話を絞ります。transformは(モジュールを使う)関数をモデルに変換するものです。返り値の型はTransformedという型ですが、ただのnamedtupleです。Transformedinitapplyという属性を持ちます。**initはモデルの初期化を行う関数で、モデルの中のパラメータの値を返します。逆にapplyはパラメータと入力を引数に与えて、そのモデルの計算結果を返す関数です。**コードの[G]の部分が該当します。initの引数は、乱数生成のキー・入力の順で、applyの引数は、モデルのパラメータ・乱数生成のキー・入力の順です。**パラメータを引数にするあたりが特徴的です。**冒頭のODE-GANのような手法の運用に便利なインターフェースです。

例の説明(その他)

jaxを触ってみると分かることですが、jaxの乱数の仕組みは他(numpyやtensorflow)と比べて複雑です。乱数生成のキーを生成する仕組みを使って・・・というような処理が必要です。高度なコントロールができるのが利点ですが、利用するには簡単なものがあると便利です。haikuのPRNGSequenceはそれを簡単に使えるものです(コードの[F]の部分)。PRNGSequenceを使うと乱数生成のキーを簡単に用意できるため、苦労が減ることでしょう。

haikuのよみ方

haikuの基本構造

haikuだけでなく、前身のsonnet・sonnet2でもそうですが、これらはシンプルな構造のライブラリになることを目指しています。その構造は少量のコア機能とそれを使った具体的な部品からなります。haikuの場合具体的には以下のものから構成されます。

役割 内容 補足
コア機能 hk.Module
部品 hk.Linearなど hk.Moduleを継承して多くの層が用意されている
便利機能 hk.transformhk.PRNGSequenceなど hk.transformはコア機能かも

sonnetも似た構造をしています(「どうせ同じでしょ」と思ってhaikuをよんだらやっぱりそうでした)。

どこをよむといいか

どこをよむといいかなんて、よむ人次第ですが、「haikuの仕組みが知りたい」「haikuのどの仕組みが他ライブラリとの違いか」といった部分に興味があったので、コア機能を中心によみました。逆にhaikuを使って自作の層を作ってみたい人は、部品を中心によむとよいでしょう。

haikuをよむ

前章で説明した通り、コア機能に興味があるので、そこを中心によんで分かったことを紹介していきます。
また、そのときのバージョンのソースコードのURLは、https://github.com/deepmind/dm-haiku/tree/300e6a40be31e35940f0725ae7ed3457b737a5a3です。

ディレクトリ構造

haikuの主要なファイル・ディレクトリの構造は以下の通りです。基本的にプログラムはhaiku/_src/XX.pyに書き、haikuディレクトリ直下のpythonファイルで公開APIの部分だけをインポートしています。

.
├── WORKSPACE
├── haiku
│   ├── BUILD
│   ├── __init__.py
│   ├── _src
│   │   ├── base.py
│   │   ├── data_structure.py
│   │   ├── stateful.py
│   │   ├── module.py
│   │   ├── transform.py
│   │   └── typing.py
│   ├── data_structures.py
│   ├── experimental.py
│   ├── initializers.py
│   ├── nets.py
│   ├── pad.py
│   ├── testing.py
│   └── typing.py
├── requirements-jax.txt
├── requirements-test.txt
├── requirements.txt
├── setup.py
└── test.sh

haiku/_srcディレクトリの中身はコア機能に関係するものだけを書きました。

haiku/_src/base.py

haiku/_src/base.pyには、name_scopeを支える仕組みやPRNGSequenceが書かれています。
name_scopeには、このスコープ中のこのスコープの・・・というようなスコープの階層構造があります。name_scopeでは今、どのスコープにいるかを把握しないといけません。そのために「今どこのスコープにいるか」を意味するグローバル変数を用意しています。具体的には、stackとして実現しており、スコープに入ったらpush、スコープから出たらpopする。スコープを意識した処理をするときは、例えばstackの最後を見て処理をするであったり、stackの先頭から順に何かを適用したりします(再帰で似たことを実現しているコードがなん箇所かにあります)。雑にname_scopeと書きましたが、このような階層構造を、変数・状態変数(BatchNormalizationの移動平均等)・乱数・モジュール・名前それぞれのためにstackを用意していました。

haiku/_src/data_structure.py

haiku/_src/data_structure.pyにはhaikuのコア機能内のために使われる基礎的なデータ構造が書かれています。主要なものはStackとFlatMappingです。
Stackはbase.pyで説明したように、階層構造のどこにいるのかを表現するために使われます。
FlatMappingはモデルの変数一覧等に用いられます。haikuのモデルの変数は階層構造があります。その階層構造を表現するためにdictのかわりにこのFlatMappingを使います。Transformedinitの返り値などで用いられます。FlatMappingはjax.tree_XXというような名前のjaxの関数と一緒に使うために準備されたものです。jax.tree_XXという関数を使って、パラメータの更新や、勾配同士の演算をしたり(ルンゲクッタ法のような)します。

haiku/_src/stateful.py

haiku/_src/stateful.pytransofrom等の中でjaxの関数とhaikuのスコープを整合的に運用するための機能が実装されています。jaxの基礎的な関数をhaiku用にラップしたものがあります。細かなパーツがたくさんある感じなのと、そこまでjaxのあれこれを知っているわけではないので、現状は深く踏み込めなくてこのファイルはこれくらいで諦めました。

haiku/_src/module.py

haiku/_src/module.pyは名前からわかるとおり、Moduleが書かれています。また、Moduleの基本機能のname_scope自体もこのファイルにあります。

haiku/_src/transform.py

haiku/_src/transform.pyも名前からわかるとおり、transformが実装されています。実装を見ていくと、少したらいまわしにされますが、initapplyの二つの関数を適切なスコープで作ってTransformedとして返すだけの関数です。

haiku/_src/typing.py

haiku/_src/typing.pyはよく使う型を特別に変数に代入しているだけです。一箇所気になるのは、次のような行です。

Module = typing._ForwardRef("haiku.Module")

Moduleが最も基礎的な機能なので、typing.pyで型ヒント用の変数を用意したいのはよくわかります。forward referenceという実験的な機能があるようで、それが使えたら使うということをしています。

最後に

ということでね、haikuをよみました。

2020年はtransformerの高速化や画像認識への適用、GPT3、alphafoldの活躍といった話題がありましたが、2021年はどんな技術・手法が登場するのでしょうか。今はtensorflow/pytorchの2強状態ですが、jaxやhaikuじゃないと上手く書けないような高度な手法とかが登場したりするんでしょうかね?東京オリンピックもどうなるんでしょうね?開会式などで日本の技術力をアピールするような深層学習を使った何かがみれたりするんでしょうかね?

ポエム

新年や 離れて祝い haikuよむ

11
5
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?