LoginSignup
2
1

Vector Jacobian Product (VJP) と Jacobian Vector Product (JVP)

Posted at

概要

Atılım Güneş Baydin, Barak A. Pearlmutter, Don Syme, Frank Wood, Philip Torr, "Gradients without Backpropagation", https://arxiv.org/abs/2202.08587

という論文を読みました。バックプロパゲーションを使わずに最適化のための目的関数の勾配を計算していて興味深かったです。論文中で出てくるvector–Jacobian product (VJP)と Jacobian–vector product (JVP) を知らなかったので、本記事でまとめたいと思います。

vector–Jacobian product (VJP)

2層のニューラルネットワークの例を考えます。

\begin{align}
& z_j^{(0)} = x_j, \\
& z_j^{(\ell)} \left( \boldsymbol{w}^{(\ell)}, \boldsymbol{z} \right) = h \left( a_j^{(\ell)} \right) = h(\sum_{i=1} w_{ji}^{(\ell)}z_i^{(\ell-1)} + b_j^{(\ell)} ), \quad (\ell=1), \\
& y_k(\boldsymbol{x}, \boldsymbol{w}^{(1)}) = \sigma\left( \sum_{j=1} w_{kj}^{(2)} z_j^{(1)} + b_k^{(2)} \right), \\
& \min_\boldsymbol{w}E(\boldsymbol{y}, \boldsymbol{t}).
\end{align}

ここで、$\boldsymbol{x}$が入力、$\boldsymbol{w}$が重み、$\boldsymbol{b}$がバイアス、$h$は中間層の活性化関数、$\sigma$は最終層の正規化関数、$\boldsymbol{t}$が教師ラベル、$E$が目的関数です。勾配はバックプロパゲーションで求められたのでした。

\begin{align}
\frac{\partial E}{\partial w_{kl}^{(1)}}
&= \sum_{ij} \frac{\partial E}{\partial y_i} \frac{\partial y_i}{\partial z_j^{(1)}} \frac{\partial z_j^{(1)}}{\partial w_{kl}^{(1)}} \\
&\equiv \sum_{ij}  v_i J_{ij} J_{j,kl}^{(w)}, \\
\frac{\partial E}{\partial x_{k}}
&= \sum_{ij} \frac{\partial E}{\partial y_i} \frac{\partial y_i}{\partial z_j^{(1)}} \frac{\partial z_j^{(1)}}{\partial x_{k}} \\
&\equiv \sum_{ij}  v_i J_{ij} J_{jk}^{(x)}.
\end{align}

式中に$v_i J_{ij}$がありますが、これがVJPです。関数$f: \mathbb{R}^n \rightarrow \mathbb{R}^m$に対してバックプロパゲーションを実行する際に現れるのが$\sum_{i}v_i J_{ij}$で$J\in \mathbb{R}^{m\times n}$はヤコビ行列、$\boldsymbol{v}\in \mathbb{R}^m$は後ろの層から伝わってきた誤差ベクトルです。層が深ければ、重みを$v_i J^{(2)}_{ij}J^{(1)}_{jk}=\tilde v_j J^{(1)}_{jk}$のようにチェインして、VJPとして計算することができます。したがって各層のヤコビアンの形をあらかじめ求めておけばバックプロパゲーションを計算できます。

コードでも確認してみます。線形層のクラスを見てみましょう。

main.py
import numpy as np

class Linear:
    def __init__(self, W, b):
        self.W = W
        self.b = b
        self.x = None
        self.dW = None
        self.db = None

    def forward(self, x):
        self.x = x
        out = np.dot(x, self.W) + self.b

    def backward(self, v):
        dx = np.dot(v, self.W.T)  # xに関するヤコビアンのVJP
        self.dW = np.dot(self.x.T, v)  # Wに関するヤコビアンのVJP
        self.db = np.sum(v, axis=0)
        return dx

線形層だとヤコビアンがつまらない形ですが、伝搬してきたベクトルvにヤコビアンをかけてdWdxを計算しています。

Jacobian Vector Product (JVP)

関数$f: \mathbb{R}^n \rightarrow \mathbb{R}^m$に対して$\sum_{j} J_{ij}v_j$がJVPです。$J\in \mathbb{R}^{m\times n}$はヤコビ行列、$\boldsymbol{v}\in \mathbb{R}^n$は方向ベクトルです。$\boldsymbol{v}$の次元がVJPと異なり、$f$の引数の次元であることに注意してください。

JVPは方向微分になります。

\begin{align}
\mathrm{JVP}(\boldsymbol{f}, \boldsymbol{v}) = J \boldsymbol{v} =\sum_i \frac{\partial \boldsymbol{f}}{\partial x_i} v_i.
\end{align}

合成関数$\boldsymbol{f}(\boldsymbol{x}) = \boldsymbol{f}^{(2)} (\boldsymbol{f}^{(1)} (\boldsymbol{x}))$に対するJVPを計算してみると、

\begin{align}
\mathrm{JVP}(\boldsymbol{f}, \boldsymbol{v}) 
&= \sum_i \frac{\partial \boldsymbol{f}}{\partial x_j} v_i \\
&= \sum_{i,j} \frac{\partial \boldsymbol{f}^{(2)}}{\partial f_j^{(1)}}\frac{{\partial f_j^{(1)}}}{\partial x_i} v_i \\
&= \mathrm{JVP}(\boldsymbol{f}^{(2)}, \mathrm{JVP}(\boldsymbol{f}^{(1)}, \boldsymbol{v}) ) 
\end{align}

となることが分かります。したがって、JVPの計算ではバックプロパゲーション同様にチェインルールが使えます。

ニューラルネットワークの場合は次のようになります。

\begin{align}
\sum_k \frac{\partial E}{\partial x_{k}} v_k
&= \sum_k  \sum_{ij} \frac{\partial E}{\partial y_i} \frac{\partial y_i}{\partial z_j^{(1)}}  \frac{\partial z_j^{(1)}}{\partial x_{k}} v_k \\
&\equiv \sum_{ijk}  J^{(E)}_i J_{ij} J_{jk}^{(x)} v_k.
\end{align}

したがって各層のヤコビアンの形をあらかじめ求めておけば、フォワードで積をとってJVPを計算できます。

コードでも確認してみます。線形層のクラスを見てみましょう。

main.py
import numpy as np

class Linear:
    def __init__(self, W, b):
        self.W = W
        self.b = b

    def forward(self, x, v):
        self.x = x
        out = np.dot(x, self.W) + self.b
        jvp_val = self.forward_jvp(v)
        return out, jvp_val

    def forward_jvp(self, v):
        jvp_val = self.W.T @ v
        return jvp_val

JVPはforwardで計算されます。方向ベクトルvは事前に決めておく必要があります。ちなみに、JAXやPyTorchは各層にJVPの計算メソッドを持っている訳ではなく、jax.jvptorch.autograd.functional.jvpという関数を使用して、予め定義されたプリミティブな関数に対する微分をチェインルールで実行していくようです。

JAXの自動微分の解説

JAXのJVP実装の解説

JAXのJVP実装のソースコード

JAXのプリミティブ関数のJVP実装のソースコード

先程の層への実装とtorch.autograd.functional.jvpの実行結果を比べてみます。

main.py
import numpy as np
import torch
import torch.nn as nn
from torch.autograd.functional import jvp


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(2, 3)
        self.layer_2 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        return x


x = torch.Tensor([1, 2])
v = torch.Tensor([1, 1])
net = Net()
func_output, jvp_val = jvp(net, x, v)
print(func_output)  # tensor([0.3299, 0.3976, 0.3763])
print(jvp_val)  # tensor([0.0474, 0.0245, 0.0107])


x = np.array([1, 2])
v = np.array([1, 1])
layer1 = Linear(
    net.layer_1.state_dict()['weight'].T.clone().numpy(),
    net.layer_1.state_dict()['bias'].clone().numpy()
)
layer2 = Linear(
    net.layer_2.state_dict()['weight'].T.clone().numpy(),
    net.layer_2.state_dict()['bias'].clone().numpy()
)

out, jvp_val = layer1.forward(x, v)
out, jvp_val = layer2.forward(out, jvp_val)
print(out)  # [0.32985815 0.39755509 0.37627642]
print(jvp_val)  # [0.04740545 0.02448669 0.01071939]

同じ計算結果になっています。

バックプロパゲーションを使わない最適化

ニューラルネットワークのパラメータ更新は、以下の式で計算されます。

\begin{align}
\boldsymbol{w} \leftarrow \boldsymbol{w} - \eta \nabla E
\end{align}

パラメータ更新のためには勾配$\partial E/\partial \boldsymbol{w}$が必要で、通常はバックプロパゲーションで計算されます。

論文では、方向微分と方向ベクトルの積が勾配の一致推定量になることを利用しています。重みと同じ次元の方向ベクトルをガウス分布からサンプリングし、forwardを計算、JVPを求め平均を求め、勾配を計算しています。この勾配を用いて勾配降下法を行います。

\begin{align}
\mathbb{E}_{\boldsymbol{v}}\left[(\nabla E \cdot \boldsymbol{v}) \boldsymbol{v} \right]=\nabla E, \quad \boldsymbol{v} \sim \mathcal{N}(\boldsymbol{0}, I)
\end{align}

forward AD アルゴリズム

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1