4
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

はじめてのNeural Tangent Kernel(NTK)

Posted at

2018年に出て盛り上がったNeural Tangent Kernelの論文をわかりやすく解説してくれた授業がYouTubeにアップされていたので、それをまとめてみたいと思う。

元論文
Neural Tangent Kernel: Convergence and Generalization in Neural Networks (NIPS 2018)

細かい解説記事はこちら
Understanding the Neural Tangent Kernel
日本語版
Neural Tangent Kernel(NTK)の概要

読むより授業の方が自分には合っているので、忘れないようにメモ。

解きたい問題: 入力 $x_i \in \mathbb{R}^d$ に対して、出力 $y_i \in \mathbb{R}$ を知りたい。

linear regressionで解く場合は、linear model $f(W, x) = W^T x = \hat{y}$ を使い、以下のLoss関数 $\mathcal{L}$ を小さくする。

$\mathcal{L}(W) = \dfrac{1}{2} \sum_{i=1}^n (y_i - f(W, x_i))^2$

ここで、$x$ は $d$ 次元なので、$f$ は $d$ 次元にしかなりえない。
これだと小さくて嫌なので、これをnon-linearityを使って $D$ 次元まで大きくしてみたい。
例えば、 $x=(a_1, a_2, a_3)$ の3次元だったのを、$\phi(x) = (a_1, a_2, a_3, a_1a_2, a_1a_3, a_2a_3)$ とかにすると $D=6$ 次元になる。
こうすると、 $f(W, x) = W^T \phi(x)$ で次元が増え、

  • $W$に対してはlinear
  • $x$に対してはnon-linear

これを、そのまま、上のLoss関数に使うと次元が増えて良いけど、 $\phi(x)$ がいつも同じnon-linearityで、$D >> d$の場合、計算量が大変なことになる。
例えば、$\phi$ がk次元のpolynomialだと $D=\mathcal{O}(d^k)$ で大きくなる。
ということで、出てくるのが一世を風靡したKernel Trick!

Kernel Trick の復習
Kernel Function: $k(x_i, x_j) = <\phi(x_i), \phi(x_j)>$
Kernel Matrix: $\mathcal{K} \in \mathbb{R}^{n \times n}$ の $(i,j)$ elementは、kernel functionの $k(x_i, x_j)$ で、positive semi-definite matrix(PSD)

$\phi(x_i) \in \mathbb{R}^D$ なので、$k(x_i, x_j)$ の計算も $\mathcal{O}(D)$ で大きくなるが、 $\phi(x_i)$ によっては、簡単に計算する方法がある。
例えば、k-polynomial kernelの場合、$k(x_i, x_j) = (c+x_i^T x_j)^k$ で計算が出来て $\mathcal{O}(d)$ で済む。
あまりにもお手軽なので、いろいろなkernelが考え出され、一気にいろいろな方法に拡張されて

  • SVM は kernel SVM に
  • ridge regression は kernel ridge regression に
  • PCA は kernel PCA に

なった。DNNが来る前は、kernel methodの時代で、いっぱい解説がある。

Neural Network に当てはめてみる
non-linearityを入れた $\phi(x)$ は最初にパラメータを設定してから変わらないのが嫌なので、学習で決められるようにする。
これは、Hidden Layerが一層だけのネットワークと同じで $f(W, x)$ の $W$ は全てのネットワークWeightとbiasということになる。
Loss関数も変わらず、

$\mathcal{L}(W) = \dfrac{1}{2} \sum_{i=1}^n (y_i - f(W, x_i))^2$

これをgradient descent法で解くと

$W(t+1) = W(t) + \eta_t \sum^n_{i=1}(y_i-f(W(t), x_i))\nabla_Wf(W(t), x_i)$

$\eta_t$ は、learning rate。
$\nabla_Wf$は、linear regressionのときは固定だったけど、今回は、$W(t)$ に合わせて変わる。
じゃあ、どんだけ変わるんだってのを実験して見てみると、案外、初期値 $\mathcal{N}(0, 1)$ から変わらないという。
いつもってわけではないけど、ネットワークのWidthが大きいとほとんど"static"で、学習は"Lazy"。実際どんだけ変わっているかは、上にリンクを貼った解説記事に出ている。ニューロンが多いと(widthが大きいと)見た目、本当に変わらない。
parameter spaceで考えるとあるボールの中に残っているか出て行ってしまうか
image.png
初期値から対して変わらんってなったら、次は初期値 $W_0$ あたりでTaylor approximationしてみたくなる。

$f(W, x) \simeq f(W_0, x) + \nabla_W f(W_0, x)^T (W-W_0) + \cdots$

これで考えるためにはHessianがboundedである必要がある(Hessian control)。ということで、この頃出ているover parameterizationの解析につながる。
最初の2つの項は、$W$ に対してはlinearで、$x$ に対してはnon-linear。
ここで、model $f$ をニューラルネットに合わせて、$W$ が、バイアス $b$ と重み $a$ 、activation関数 $\delta$ でニューロン数 $m$ のhidden layer 1層とすると

$f_m(W, x) = \dfrac{1}{\sqrt{m}} \sum_{i=1}^m b_i \ \delta(a_i^T x)$

gradientを取ると

  • $\nabla_{a_i} f_m(W, x) = \dfrac{1}{\sqrt{m}} \sum_{i=1}^m b_i \ \delta'(a_i^T x) \ x$
  • $\nabla_{b_i} f_m(W, x) = \dfrac{1}{\sqrt{m}} \sum_{i=1}^m \delta(a_i^T x)$

Neural Tangent Kernel (NTK)

ここで、Kernel trickを思い出す。
まず、 neural tangent kernel $\phi(x) \triangleq \nabla_W f(W_0, x)$ と定義。kernel function は $k(x_i, x_j) = <\phi(x_i), \phi(x_j)>$。
$W$ を表す $a$ と $b$ に分割して

$k_m(x, x') = k_m^{(a)}(x, x') + k_m^{(b)}(x, x')$

  • $k_m^{(a)}(x, x') = \dfrac{1}{m} \sum_{i=1}^m b_i^2 \ \delta'(a_i^T x) \ \delta'(a_i^T x') \ (x \cdot x')$
  • $k_m^{(b)}(x, x') = \dfrac{1}{m} \sum_{i=1}^m \delta(a_i^T x) \ \delta(a_i^T x')$

ニューロン数 $m$ の平均になっているので、$m$ を無限値に飛ばしてみると

  • $k_m^{(a)}(x, x') \xrightarrow{m \rightarrow \infty} k^{(a)}(x, x') = \mathbb{E}[\ b^2 \ \delta'(a^T x) \ \delta'(a^T x') \ (x \cdot x') \ ]$
  • $k_m^{(b)}(x, x') \xrightarrow{m \rightarrow \infty} k^{(b)}(x, x') = \mathbb{E}[\ \delta(a^T x) \ \delta(a^T x') \ ]$

さらにここで、activation function $\delta$ をReLUとして、$a_i$ のdistributionをrotation invaricant in $\mathbb{R}^d$ とすると

  • $k^{(a)}(x, x') = \dfrac{(x \cdot x') \ \mathbb{E}[\ b^2 ]}{2\pi}(\pi-\theta(x, x'))$
  • $k^{(b)}(x, x') = \dfrac{||x|| \ ||x'|| \ \mathbb{E}[\ ||a||^2 ]}{2\pi d}((\pi-\theta(x, x'))\cos(\theta)+\sin(\theta))$

と書け、$\theta(x, x')$ は、$x$, $x' \in [0, \pi]$ の角度。

Gradient flow

上で書いたgradient descent法を $y \in \mathbb{R}^n$, $f(W, x) = \hat{y} \in \mathbb{R}^n$ のベクターに置き換えると

$W(t+1) = W(t) + \eta_t \sum^n_{i=1}(y_i-f(W, x_i))\nabla_Wf(W_t, x_i) = W(t) + \eta_t (y-\hat{y})\nabla_W \hat{y}$

でlearning rate $\eta$ が0に近づくときは、普通のgradient flowと同じように

$\dfrac{W(t+1) - W(t)}{\eta_t} = (y-\hat{y})\nabla_W \hat{y} \ \xrightarrow{\eta \rightarrow 0} \ \dfrac{dW(t)}{dt}= - (\hat{y}-y)\nabla_W \hat{y}$

となり、$W$ のdynamicsがどう変わるかを表せる。

$\hat{y}$ の変化をchain ruleを使って表す際に、gradient flowを使うと

$\dfrac{d\hat{y}(W(t))}{dt} = \nabla_W\hat{y}(W(t))^T \ \dfrac{dW(t)}{dt} = - \nabla_W\hat{y}(W(t))^T \ \nabla_W\hat{y}(W(t)) \ (\hat{y}(W)-y)$

で、あら不思議、kernel functionが出てきた。
ここで、widthが大きい場合、$W$ は初期値 $W_0$ からあんまり変わらないってのからneural tangent kernel matrixを使うと、

$\dfrac{d\hat{y}(W(t))}{dt} \thickapprox - \mathcal{K}(W_0)(\hat{y}(W)-y)$

となる。ここで、$u = \hat{y}(W)-y$ とすると、

$\dfrac{du}{dt} \thickapprox - \mathcal{K}(W_0) u$

で、どっかで見たことのある常微分方程式が出てきて、$u(t) = u(0) e^{-\mathcal{K}(W_0)t}$ と解が表せる。

この解析は、over-parameterizedの状態なので、$\mathcal{K}(W_0) > 0$ になり、positive definite(semiがつかない)でeigenvalue $0<\lambda_1 < \lambda_2 < \cdots < \lambda_n$ とそれに付随するeigenvector $v_i$ を使い

$\mathcal{K}(W_0) = \sum^n_{i=1} \lambda_i \ v_i \ v_i^T$

と表すと

$u(t) = u(0) \prod_{i=1}^n e^{-\lambda_i \ v_i \ v_i^T}$

となる。$\lambda$ は、収束率を表す。一番小さいeigenvalueが遅いときを表して、一番重要。

ということで、良さそうだけど、SOTAのネットワークの方が未だに結果としては良いため、あまり実用上は使われていない。とはいえ、この頃のbayesian系の流れから、まだまだ改良されていくかもしれない。

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?