$\def\bm{\boldsymbol}$
概要
- オンラインで開催している『モデルベース深層学習と深層展開』読み会で得られた知見や気づきをメモしていく
- ついでに、中身の理解がてらJuliaサンプルコードをPythonに書き直したコードを晒していく
- 自動微分ライブラリにはJAXを使用する
第1回
大まかな内容
- 深層展開とは
- アイデアの中核
- 深層展開とモデルベース深層学習の関係
- 深層学習の基礎技術の理解
- MLPの数式理解
- ミニバッチ学習法、確率的勾配法、誤差逆伝搬法
- 自動微分と計算グラフの理解
議論になったこと
1.3.2節 あたり
- モデルベースでやれるなら、それで済まないのか
- ->どんなにうまくやってもモデル化誤差が残る。また、経年劣化等で真のモデルも変化する。
- 深層展開では式の形を作ることに相当することもできるのだろうか
- ->学習対象のパラメータの設計次第でそれに近いことはできるのではないか?(多項式の係数など)
- でもあくまでも本質は反復アルゴリズム内のパラメータの学習
- ->学習対象のパラメータの設計次第でそれに近いことはできるのではないか?(多項式の係数など)
2.1節 あたり
- 「パラメトリックモデルの学習後パラメータ$\theta^*$から未知のシステムの確率的特徴をとらえることが可能となる」というのはどういう意味か
- ->結論出ず。(後述の、「個人的気づき」に後で思ったことを記述)
2.2節 あたり
-
MLP-Mixerってどんなもん?
- ->入力データを細かい領域に分け、MLPなどから成る並列したモジュールに食わせて大規模なデータセットに対しても効率的に特徴を学習し、高度なパターンを識別しようというモデル。だが、実用上はCNNなどのほうが性能が出やすいっぽい。
2.3節 あたり
- Optunaってどんなもの?
- ハイパーパラメータなどを半自動的に求める便利なツール
- APIがかなり優秀で使いやすい
- ハイパーパラメータなどを半自動的に求める便利なツール
2.4節 あたり
- 式(2.18)は誤植では?$\frac{\partial L}{\partial \theta_n}$が正しいのでは?
- ->多分そうじゃね
2.4.2節 あたり
- 「自動微分は効率の良い数値微分法のひとつ」的な事が書いてるが、自動微分は数値微分の一種なのか、別物ではないのか
- ->教科書の"数値微分"は数値解析で微分を求める方法群を意味していて、”前方差分”とかの狭義の数値微分ではないのでは
- 双対数ってよくわからん
- ->虚数みたいに、ある特定の演算で好ましい振る舞いをするようにデザインされた仮想数
- ->双対数を使うと、積の微分の計算が簡単な四則演算に置き換えられる(式(2.40)の第二項に注目)
- 順方向自動微分の方にだけ双対数が利用される理屈は?
- ->結論出ず。
2.4.3節 あたり
- NNの誤差逆伝搬は逆方向自動微分の1本道版だが、NNの計算だと必ず1本道になるということか、であれば何故そうなるのか
- ->結論出ず。
個人的な気づきなど
- 2.1節で残った疑問(未知のシステムの確率的特徴うんぬん)についての再考
- 「学習済みパラ$p(\theta|\mathcal{D})$と尤度$p(y|\theta,x)$さえあれば、未知の入出力関係 $p(y^* |x^*)$を周辺化で求める事できるよ」的な話かなと思った
プログラムでの理解
- サンプルコードがほぼ出てこない回だったので、番外編的にJAXでの計算グラフの可視化をやってみる
- プログラムの全文はこの辺に上げている
問題設定
教科書2.4.2節(2.28)式
$$
f(x_1,x_2) = (x_1 + x_2)x_1^2$$
の$x_1$による偏微分をやってみる
Pythonで実装
必要ライブラリインポート
import jax
from jax import grad
微分対象の関数の定義
@jax.jit
def f(x1, x2):
return (x1 + x2) * x1**2
df = grad(f, argnums=0) #x1での偏微分
@jax.jit
はjust-in-timeコンパイラのデコレーター。詳しくは次回移行またどこかで出てきたときに触れる。別に今回は付けなくてもいいが、これのアリナシで計算グラフや出力される図も変わってくるので注意。
微分前の関数の計算グラフの描写
jaxのxla_computation()
を利用すると、計算グラフを外部ライブラリでも扱える形で出力できる
z = jax.xla_computation(f)(1.0, 2.0)
save_name = "./f_cg.dot"
with open(save_name, "w") as file:
file.write(z.as_hlo_dot_graph())
.dot形式で保存したファイルを画像にするには、bash等でgraphvizのコマンドを使う
dot f_cg.dot -Tpng > f_cg.png
以下のような、画像が生成される。
自乗が普通の掛け算で表現されてたりと、微妙な違いはあるが教科書図2.6と同等の図が得られた。
偏微分された関数の計算グラフの描写
同様に、
z = jax.xla_computation(df)(1.0, 2.0)
save_name = "./df_cg.dot"
with open(save_name, "w") as file:
file.write(z.as_hlo_dot_graph())
とすると、以下のようなグラフが描写される
図だけだとよくわからないので、テキスト形式でもプリントしてみる。
print(jax.make_jaxpr(df)(1.0, 2.0))
{ lambda ; a:f32[] b:f32[]. let
_:f32[] c:f32[] d:f32[] e:f32[] = pjit[
name=f
jaxpr={ lambda ; f:f32[] g:f32[]. let
h:f32[] = add f g
i:f32[] = integer_pow[y=2] f
j:f32[] = integer_pow[y=1] f
k:f32[] = mul 2.0 j
l:f32[] = mul h i
in (l, i, h, k) }
] a b
m:f32[] = pjit[
name=f
jaxpr={ lambda ; n:f32[] o:f32[] p:f32[] q:f32[]. let
r:f32[] = mul o q
s:f32[] = mul r p
t:f32[] = mul q n
u:f32[] = add_any s t
in (u,) }
] c d e 1.0
in (m,) }
前半は前向き計算で、教科書(2.31)~(2.39)および(2.41),(2.42)式らへんにあたると思う。
後半が、逆方向計算で(2.43)式にあたる?
grad
関数は教科書でいうところの逆方向自動微分で実装されているようだ。
jaxには、多次元の偏微分をするための関数jacfwd
とjacrev
があり、コチラはそれぞれ順方向自動微分、逆方向自動微分のよう。コード全文の方ではそちらでも少し遊んでいる。
その他
- 今回は結構基本的な内容が多かったためか、理論に対する質問があまり上がらなかった。ファシリテーションもちょっと難しかった。。。がんばろう