Help us understand the problem. What is going on with this article?

深層学習の理論を眺める:Neural Tangent Kernelと平均場理論

「深層学習はその原理的な背景はよくわからない」,「深層学習は計算機の力でぶん殴ってる」,「深層学習は数理的な側面があまり整備されていない」…こういう話をみなさま一度は耳にした事があるのではないでしょうか?
こういう課題感に対して,ここ数年で深層学習の理論的な研究への注目度が上がっている気がします.今回は,深層学習を支える(かもしれない)ここ数年で特に注目を浴びた理論について紹介します.

はじめに

Advent Calendarな皆様,こんにちはこんばんはおはようございます.クロステック開発部に所属している田中と申します.昨年度は先進技術研究所でAdvent Calendarをひっそりと書いていましたが,先進技術研究所とサービスイノベーション部と5Gイノベーション推進室を跨いだ組織再編があり,クロステック開発部に異動となりまして,今年はR&Dの3部署(クロステック開発部・サービスイノベーション部・移動機開発部)合同でAdvent Calendarをすることになりました.

開発部の人達が記事を書くということもあり,作ってみた系・やってみた系の記事が多くなっているのではないでしょうか.
そこで,私の記事ではいわゆる作ってみた系とは対局と言っても良い,Deep Learningの理論的側面について書きたいと思います.と言っても,統計的学習理論のバリバリの専門家向けではなく,エンジニア等の現場にいる人に向けて記事を書きたいと思います.普段使っている Deep Learning の裏側にどんな理論的な広がりがあるのかを,ふんわりとでも感じてもらえると嬉しいです.

この記事では,Neural Tangent Kernelという理論と,平均場理論の2つを紹介します.ただ,これらの紹介を論文に沿って解説すると非常に難しいので,具体的な回帰問題に限定して話を進めていきます.
また,紹介する理論を知っているとどういう時に役立つのかも,具体例をあげて説明します.

Neural Tangent Kernel

まずは Neural Tangent Kernel というものについて解説します.
Neural Tangent Kernel は関数空間における内積によって定義される量で,これを使うと Deep Neural Network の学習の変化を線形の常微分方程式で記述できるようになります.おそらく分かりにくいと思うので例を出しますと,高校の物理において物体の運動を運動方程式
$$
m\frac{d^2 x}{dt^{2}} = F
$$
で表せるというのと同じ感覚です.
以下では,この Neural Tangent Kernel について回帰の場合を考えて解説していきます.

Neural Tangent Kernel って何?

$L$層のDNN $f: \boldsymbol{x}\to y; \mathbb{R}^d \to \mathbb{R}$ を学習データ $\left\{ (\boldsymbol{x}_i, y_i)\right\}_{n=1}^{N}$ で学習する場合を考えます.目的変数の定義域を $\mathbb{R}$ としているので,今回は損失関数として
$$
\ell (\boldsymbol{x}, y)=\frac{1}{2} \left\{ y - f(\boldsymbol{x}; \boldsymbol{\theta}) \right\}^2
$$
を使います.ここで,重み行列やバイアスなどの学習可能なパラメータを全部まとめて$\boldsymbol{\theta}\in \mathbb{R}^p$で表しています.

ニューラルネットを確率的勾配降下法で学習するときには,

\begin{align}
\Delta \boldsymbol{\theta} =& -\eta \frac{\partial \ell}{\partial \boldsymbol{\theta}} \\
=& -\eta \{ y - f(\boldsymbol{x}; \boldsymbol{\theta})\} \frac{\partial}{\partial \boldsymbol{\theta}}f (\boldsymbol{x}; \boldsymbol{\theta})
\end{align}

を用いて
$$
\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} + \Delta \boldsymbol{\theta}
$$
でパラメータを更新しますね.もちろん,$\eta$ は学習率です.
学習データ(通常はミニバッチ)に対して上記のパラメータ更新を繰り返すわけですが,その繰り返しのステップ $t$ を連続時間と考えて,学習データ から計算される $\boldsymbol{\theta}$ の変化量 $\Delta \boldsymbol{\theta}$ は
$$
\dot{\boldsymbol{\theta}}_{t} = -\frac{\eta}{N} \sum_{n=1}^{N} \{ y_n - f(\boldsymbol{x}_n; \boldsymbol{\theta})\} \frac{\partial}{\partial \boldsymbol{\theta}}f (\boldsymbol{x}_n; \boldsymbol{\theta})
\tag{1}
$$
のように表されます.特に何も難しいことはありません.学習データ上での平均を取っただけです.

さて,ここまではパラメータの時間変化を考えていましたが,関数 $f$ もパラメータが更新されるたびに,つまり $t$ に依存して変化します.と言うことで,$f$の学習における時間変化(ダイナミクスとでも呼びましょうか)について考えてみましょう.$f$ を時間 $t$ で微分すると
$$
\dot{f}(\boldsymbol{x}; \boldsymbol{\theta}_t) = \left( \frac{\partial}{\partial \boldsymbol{\theta}} f(\boldsymbol{x}; \boldsymbol{\theta}_t) \right)^{\mathsf{T}} \dot{\boldsymbol{\theta}}_t
$$
になります. 式 (1) を上式の $\dot{\boldsymbol{\theta}}$ に代入すると
$$
\dot{f}(\boldsymbol{x}; \boldsymbol{\theta}_t )=
-\frac{\eta}{N}\sum_{n=1}^{N} \left( \frac{\partial}{\partial \boldsymbol{\theta}}f(\boldsymbol{x}; \boldsymbol{\theta}_t) \right)^{\mathsf{T}} \frac{\partial}{\partial \boldsymbol{\theta}}f(\boldsymbol{x}_{n} ; \boldsymbol{\theta}_t)\left\{ y_n - f(\boldsymbol{x}_n; \boldsymbol{\theta}_t) \right\}
$$
となります.そして,この式の一部になっている
$$
K(\boldsymbol{x}, \boldsymbol{x}_n) := \left( \frac{\partial}{\partial \boldsymbol{\theta}}f(\boldsymbol{x}; \boldsymbol{\theta}_t) \right)^{\mathsf{T}} \frac{\partial}{\partial \boldsymbol{\theta}}f(\boldsymbol{x}_{n} ; \boldsymbol{\theta}_t)
$$
を Neural Tangent Kernel (NTK) と呼びます.

NTKを考えると何が嬉しいのか?

NTKを使って $f$ の学習中の変化を記述すると
$$
\dot{f}(\boldsymbol{x}; \boldsymbol{\theta}_t ) = -\frac{\eta}{N}\sum_{n=1}^{N} K(\boldsymbol{x}, \boldsymbol{x}_n) \left\{ y_n - f(\boldsymbol{x}_n; \boldsymbol{\theta}_t) \right\}
\tag{2}
$$
となりますね.これは $f$ についての線形微分方程式ですが,$K$ が $\boldsymbol{\theta}$ に依存しているので,このままではとても扱いづらいわけです.
しかし,実は Jacot et al.(2018)や Lee et al. (2019) によって,「ニューラルネットの各層のユニット数が十分に大きいとき,NTK $K$ は任意の時刻 $t$ においてほとんど変化せず,ランダムに定められた重みから作る初期カーネル $K_{0}$を用いて学習のダイナミクスを記述できる」という旨の定理が示されました.これによって式 (2) は線形微分方程式になり,陽に解けます.感動的ですね🥺
しかも,学習の最適解 $\boldsymbol{\theta}^*$ はランダムに定めた初期値 $\boldsymbol{\theta}_{0}$ の近傍に存在するということも示されています(直感的にはなかなか信じがたい結果かもしれませんが).

Random Neural Network の平均場理論

次に, Neural Network の平均場理論 or 神経統計力学と呼ばれているものを紹介します.
こちらもここ数年で注目を浴びている(と個人的には思っている)理論で, Neural Network をマクロ的な視点から(ちょっと語弊ある?)眺めることができる理論です.

平均場理論とはなにか?

めちゃくちゃ大雑把に言うと,「ニューラルネットの個別の入力に対してアレコレ考えるのではなく,もう少しマクロな量に対して色々考えてみましょう」というのが平均場理論(または神経統計力学)です.「何もわからん」って状態ですよね.よくある例を出して説明していきます.

全結合の $L$ 層ニューラルネットワークを考え,第 $l$ 層の $i\, (i=1,\dots, M)$ 番目のユニットが以下のように書けるとします:
$$
h_{i}^l = \phi (u_{i}^{l}),\quad u_{i}^{l}=\sum_{j=1}^{M}W_{ij}^{l}h_{j}^{l-1} + b_{i}^l.
$$
活性化関数が $\phi$ で,中間層のユニット数は全て $M$ で固定しています.また,重み行列 $W$ の各要素とバイアス項 $b$ は次のように初期値を生成します:
$$
W_{ij}^{l} \sim \mathcal{N} \left( 0, \frac{\sigma_{w}^2}{2} \right),
\quad b_{i}^{l} \sim \mathcal{N} \left( 0, \sigma_{b}^{2}\right).
$$
このニューラルネットワークを伝搬する活動の平均値(活動度と呼びましょう)
$$
q_{l} = \frac{1}{N} \sum_{i=1}^{M} \left( u_{i}^{l}\right) ^{2}
$$
を考えてみます.$M$ が十分に大きいとき,
$$
q_{l} = \frac{\sigma_{w}^{2}}{\sqrt{2}} \int \phi \left( \sqrt{q_{l-1}}z \right)^{2} \exp \left( -\frac{z^{2}}{2}\right)\, dz
$$
となりますが,この結果から分かるように $q_{l}$ は個別の結合や入力等のミクロな量に依存せずに $q_{l-1}$ によって定まります.このようにミクロな結合の情報に依存しない変数を巨視的変数と呼び,このような巨視的変数を主眼に研究するのがニューラルネットワークの平均場理論になります(専門の人に刺されそうな説明ですが…).

平均場理論と勾配消失

ここではニューラルネットの平均場理論があると何が嬉しいのかという話をします.
私が大学院生の頃(2016年ごろ)に研究室でニューラルネットを学習させていると,学習が全く進まないという事がありました.そうです.自分で一からニューラルネットを設計した事がある人なら誰もが通るであろう道,勾配消失です.当時の先輩である Shinagawa さんに話したところ「ニューラルネットの気持ちが分からないのか?君自身がニューラルネットになることだ( ・`ω・´)キリッ」という某オシャレ漫画よろしくの大変ありがたいアドバイスを頂戴しました.

さて,そんな勾配消失ですが,どんな時に起こるかを平均場理論で説明できます.
ニューラルネットワークの Back Propagation は Jacobi行列
$$
J_{l} := \frac{dh_{L}}{du_{l}}
=D_{l}W_{l+1}^{\mathsf{T}}D_{l+1}\cdots D_{L-1}W_{L}^{\mathsf{T}},
\quad D:= \mathrm{diag}\left(\phi' (u_{l})\right)
$$
によって構成されますが,その大きさ
$$
\tilde{q}_{l} = \mathrm{trace} \left( J_{l}^{\mathsf{T}} J_{l}\right)
$$
は巨視的変数となります.この $\tilde{q}_{l}$ の挙動に関して,入力層に逆伝搬されてくる量は
$$
\tilde{q}_{1} = \chi^{L-1},\quad
\chi = \sigma_{w}^{2} \int \phi' \left( \sqrt{q_{l-1}}z \right)^{2} \exp \left( -\frac{z^{2}}{2}\right)\, dz
$$
になる事が示せます.これはつまり, $\chi < 1$ だと $\tilde{q}_{1} \to 0$ となって勾配が消失して,$\chi > 1$ だと $\tilde{q}_{1} \to \infty$ となって勾配が発散することを表しています.ということで,ニューラルネットワークの学習を効率的に進めるには,消失も発散もしない相転移点の近傍に ( $\chi = 1$ になるように ) 初期パラメータを設定すれば良いということになります.

おわりに

本記事ではここ数年で注目を浴びている深層学習の理論である Neural Tangent Kernel と平均場理論を紹介しました.あまり難しい数学を使わずに,できるだけ簡単に理解できるように話を限定して紹介したので,これら2つの理論の一部分のみを紹介するという形になってしまいました.それでも,普段ガリガリとコードを書いて動かしている Deep Neural Network の学習の背景には,ここで紹介したような数理的な理論が広がっているんだと言うことを感じ取っていただけたのではと思っています.
私自身,普段の業務でこのような数理的な背景を考えることは少なく,実装して実験して…を繰り返すことの方が多い(最近ではこれも少なくなってしまった)のですが,たまにはふと立ち止まって,このような理論的な勉強をしたいと思います.こう言う理論的な事を知っていると,「なんで上手くいかないの?」と言う時に上手く推察できることも多いですし.

それではみなさま,残り数日の Advent Calendar もお楽しみください!

参考文献

  • Jacot, Arthur, Franck Gabriel, and Clément Hongler. "Neural tangent kernel: Convergence and generalization in neural networks." Advances in neural information processing systems. 2018.
  • Lee, Jaehoon, et al. "Wide neural networks of any depth evolve as linear models under gradient descent." Advances in neural information processing systems. 2019.
  • 甘利 俊一. "新版 情報幾何学の新展開" サイエンス社 2019.
dcm_hiroaki-tanaka
機械学習の理論っぽいことに興味があります.応用も好きです.
nttdocomo-tech
NTTドコモ R&DのOrganaizationアカウントです。 自然言語処理,画像処理,ビッグデータ解析,機械学習,クラウド,IoT,無線通信など幅広いトピックを扱う予定です。
https://www.nttdocomo.co.jp/corporate/technology/rd/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away