この記事はFOLIO Advent Calendar 2018、21日目の記事です。昨日はhajipionによる大規模デザインカンファレンス「デザインシップ」の着想から開会まで、超ダッシュで振り返る10ヶ月! #Designship2018でした。
この記事は近年話題に上がることが多くなってきた自動微分を実際に実装することでより理解を深めようというものです。前回の自動微分を実装して理解する(前編)では自動微分の定義とフォワードモードの実装を見てきました。今回はリバースモードの自動微分の実装にチャレンジしてみましょう。
リバースモード
現実世界(特に機械学習)の関数は一般的に出力の次元より入力の次元のほうが圧倒的に多いので、フォワードモードよりリバースモードの方が効率面で優れている場面が多々あります。しかし一般的なリバースモードの実装方法では数式を組み立てた順番を計算グラフとして保持する必要があるため、その実装はフォワードモードの時よりも複雑になってしまいます。前回と同様にソースコード変換1による実装や演算子のオーバーロードによる実装2ももちろん可能ですが、ここでは最近話題になった34シンプルなリバースモードの実装を紹介したいと思います。
双方向の計算による実装
リバースモードを使った実装が複雑になるのは、数式を組み立てたあとに逆順に導関数を計算するという双方向の計算を実装する必要があるからです。そこで数式を組み立てるのと同時に逆向きに導関数を組み立てることでリバースモードの自動微分を実現することを考えます。
この方法では順方向に値を計算する関数x -> y
と、x
とy
におけるこれから合成される関数の微分係数dfdy
を受け取ってx
における関数全体の微分係数dx
を返す関数x -> dfdy -> dfdx
を考えます(dfdy
は$\frac{\partial f}{\partial y}$、dfdx
は$\frac{\partial f}{\partial x}$のつもりです)。つまり
(x -> y, x -> dfdy -> dfdx)
という型を考えます。これを使っていくつかの関数を実装してみましょう。
pow' :: Int -> (Double -> Double, Double -> Double -> Double)
pow' n = (\x -> x^n, \x dfdy -> dfdy * n * x^(n-1))
sin' :: (Double -> Double, Double -> Double -> Double)
sin' = (sin, \x dfdy -> dfdy * cos x)
exp' :: (Double -> Double, Double -> Double -> Double)
exp' = (exp, \x dfdy -> dfdy * exp x)
実装したのはべき乗と三角関数$\sin$と指数関数です。例えばこれらの関数をうまく合成して、
f = e^{\sin x^2}
という数式を
\begin{matrix}
f &=& e^w \\
w &=& \sin u \\
u &=& x^2 \\
\end{matrix}
という風に組み立てることができたとします。すると導関数の計算は、まず最初にexp
の導関数部分が計算されるでしょう。引数にはx
と、dfdy
として
\frac{\partial f}{\partial f} = 1
が代入され、計算結果はexp x
となります。
次にsin
の導関数部分が計算されます。引数にはx
と、dfdy
として先程計算した
\frac{\partial f}{\partial f}\frac{\partial f}{\partial w} = e^x
が代入されexp x * cos x
が計算されます。
最後にpow
の導関数部分が計算されます。引数にはx
と、dfdy
として先程計算した
\left(\frac{\partial f}{\partial f}\frac{\partial f}{\partial w}\right)\frac{\partial w}{\partial u} = e^x\cos x
が代入されexp x * cos x * n * x^(n-1)
が計算されます。こうして無事に導関数を計算することができました。この計算順序を見てみると、
\left(\left(\frac{\partial f}{\partial f}\frac{\partial f}{\partial w}\right)\frac{\partial w}{\partial u}\right)\frac{\partial u}{\partial x}
という順番になっています。実はフォワードモードの自動微分は
\frac{\partial f}{\partial f}\left(\frac{\partial f}{\partial w}\left(\frac{\partial w}{\partial u}\frac{\partial u}{\partial x}\right)\right)
という順番で計算する手法だったので、リバースモードは導関数の計算順序が逆になっているのがわかると思います。
それでは本当に計算をこのような順番となるようにうまく組み立てることができるのでしょうか?今考えている型をもう一度見てみましょう
(x -> y, x -> dfdy -> dfdx)
この型、どこかで見覚えがありませんか?そう、1つ目の要素をGetter、2つ目の要素をSetterだと思えばまさにLensの型ですね!ということで自動微分の型を van Laarhoven Lens として表してみましょう。
type Lens s t a b = forall f. Functor f => (a -> f b) -> s -> f t
lens :: (s -> a) -> (s -> b -> t) -> Lens s t a b
lens sa sbt afb s = sbt s <$> afb (sa s)
type AD x dfdx y dfdy = Lens x dfdx y dfdy
type AD' x dfdx = AD x dfdx x dfdx
Lens
とlens
の定義はmicrolensライブラリから拝借してきました。これを使えば先程の関数は
pow' :: Int -> AD' Double Double
pow' n = lens (\x -> x^n) (\x dfdy -> dfdy * (fromIntegral n) * x^(n-1))
sin' :: AD' Double Double
sin' = lens sin (\x dfdy -> dfdy * cos x)
exp' :: AD' Double Double
exp' = lens exp (\x dfdy -> dfdy * exp x)
のように実装することができます。Lensになったおかげで関数の合成には通常の合成関数の演算子.
が使えるので、
f = e^{\sin x^2}
という数式は
f :: AD' Double Double
f = pow' 2 . sin' . exp'
のように実装することができます。
このAD
から関数の値を取り出す関数にはLensとして値を取り出す関数view
を使うことができます。
> view f 0
1.0
> view f (sqrt $ pi / 2)
2.718281828459045
これは
e^{\sin 0} = e^0 = 1
そして
e^{\sin \left(\sqrt{\frac{\pi}{2}}\right)^2} = e^{\sin\frac{\pi}{2}} = e
なので合っていますね。
一方、導関数の値を取り出す関数はdfdy
として$1$をセットすればいいので
grad :: Num b => AD' a b -> a -> b
grad l = set l 1
と実装することができます。
> grad f (sqrt $ pi)
-3.544907701811034
これは
\frac{\partial f}{\partial x} = 2x\cos\left(x^2\right)e^{\sin x^2}
であり
\begin{matrix}
2{\sqrt \pi}\cos\left(\left({\sqrt \pi}\right)^2\right)e^{\sin \left({\sqrt \pi}\right)^2} &=& 2{\sqrt \pi}\cos\pi e^{\sin \pi} \\
&=& 2{\sqrt \pi}\times(-1) \times e^0 \\
&=& -2{\sqrt \pi} \\
&=& -3.544\dots
\end{matrix}
となりちゃんと計算できていることが分かりました
このようにAD
をLensとして表現することで関数合成の演算子.
を使って数式を構築し、リバースモードの自動微分が実装できることが分かりました。ただこの方法は値ではなく関数を拡張しているため、自動微分したい関数をポイントフリースタイルで書かないといけないのが欠点です。文献[4]4にはこの方法を利用してニューラルネットワークを実装した例も紹介されています。
ADはLensなのか
ところでLensには満たすべき法則として以下のようなLens則5がありました。
> view l (set l v s) == v
> set l (view l s) s == s
> set l v' (set l v s) == set l v' s
しかし先程実装したf
で実験してみると
> view f (set f 5 1)
1.0168416238956406
> set f (view f 1) 1
5.815127313951733
そして
> set f 1 (set f 1 1)
5.016872534352483
> set f 1 1
2.506761534986894
となりAD
はLens則を満たしていません。
(x -> y, x -> dfdy -> dfdx)
元の型を思い出してみるとLensのGetterに対応する部分は関数、Setterに対応する部分は導関数であり、何か値を取り出したり詰めたりするための参照のような性質は持っていないので、Lensとしての性質を持たないのは当たり前のような気もします。合成関数の演算子.
を使うためにLensという概念を形式的に導入しましたが、AD
にLensが持つような抽象的な構造や性質は期待しないほうが良さそうです。
追記: 2021/12/15
Lens則は必ずしもLensが満たすべき性質とは思われてないようです。"Categories of Optics"によるとLens則を満たすLensは 一般的なLensと区別してlawfulなものと呼ばれています。
この自動微分をLensとして構築する考え方を応用した話を以下の記事に書いたので気になる人は参考にしてください。
探求: 使い勝手の良いリバースモードの実装
(ここからは私独自の試みです)
双方向の計算によるリバースモードの実装は数式の実装をポイントフリースタイルで書かなければいけないため数式の記述が煩雑になってしまいました。どうにか演算子のオーバーロードの時のように元のプログラミング言語の演算子を使って数式を組み立てるようにはできないでしょうか。ここでは今までの知識を総動員して
- EDSLから変換する先として
AD
を利用する - 1で実装したEDSLに対して演算子のオーバーロードを使う
という方針で使い勝手の良いリバースモードの実装を検討していきたいと思います。
EDSLから変換する先としてAD
を利用する
フォワードモードをソースコード変換で実装した時にEDSLをEDSLに変換するという手法を用いました。ここではEDSLをEDSLに変換するのではなく、値の組み合わせによって関数の組み合わせを表現することで、EDSLをAD
に変換することを考えます。さっそくEDSLを定義してみましょう。
data Expr = X -- 変数
| Lit Double -- 定数
| Neg Expr -- -x
| Abs Expr -- |x|
| Sig Expr -- sign x
| Add Expr Expr -- x + y
| Mul Expr Expr -- x * y
deriving Show
後々でNum
のインスタンスにすることを念頭にそれぞれの値コンストラクタを用意しています。このEDSLをAD
に変換する前にAD
の関数をいくつか用意しておきましょう。
pair' :: Lens' a b -> Lens' c d -> Lens' (a, c) (b, d)
pair' l1 l2 = lens (\(a, c) -> (view l1 a, view l2 c))
(\(a, c) (b, d) -> (set l1 b a, set l2 d c))
dup' :: Num a => Lens' a (a,a)
dup' = lens (\x -> (x, x)) (\x (dfdx, dfdy) -> dfdx + dfdy)
lit' :: Num a => a -> Lens' a a
lit' x = lens (\_ -> x) (\_ _ -> 0)
neg' :: Num a => Lens' a a
neg' = lens (\x -> -x) (\_ dfdy -> -dfdy)
abs' :: Num a => Lens' a a
abs' = lens (\x -> abs x) (\x dfdy -> dfdy * signum x)
sig' :: Num a => Lens' a a
sig' = lens (\x -> signum x) (\_ _ -> 0)
add' :: Num a => Lens' (a,a) a
add' = lens (\(x,y) -> x + y) (\_ dfdy -> (dfdy, dfdy))
mul' :: Num a => Lens' (a,a) a
mul' = lens (\(x,y) -> x * y) (\(x, y) dfdz -> (dfdz * y, x * dfdz))
pair'
は実装より型を見ると何をしたいか分かりやすいと思います。dup'
は入力を2つに分岐する関数で、このあとの実装を見ると必要性がわかると思います。あとは定数を表現するlit'
とNum
のメソッドに対応するLensたちです。これらのLensを使ってEDSLからAD
への変換を実装してみましょう。
eval :: (Expr -> Expr) -> AD' Double Double
eval f = eval' (f X)
where
eval' :: Expr -> AD' Double Double
eval' X = id
eval' (Lit x) = lit' x
eval' (Neg e) = eval' e . neg'
eval' (Abs e) = eval' e . abs'
eval' (Sig e) = eval' e . sig'
eval' (Add e1 e2) = dup' . pair' (eval' e1) (eval' e2) . add'
eval' (Mul e1 e2) = dup' . pair' (eval' e1) (eval' e2) . mul'
X
が基底部となるようにAD
への変換がうまく再帰的に書けているのがわかるかと思います。特にAdd
やMul
の変換がdup'
とpair'
を使って綺麗に配線できました。
EDSLに対して演算子のオーバーロードを使う
それでは作ったExpr
をNum
のインスタンスにしてみましょう。
instance Num Expr where
e1 + e2 = e1 `Add` e2
e1 * e2 = e1 `Mul` e2
negate e = Neg e
abs e = Abs e
signum e = Sig e
fromInteger n = Lit (fromInteger n)
難しいことは無く、単純に用意したEDSLに変換しているだけですね。以上の準備で演算子のオーバーロードを使ったリバースモードの自動微分を実装することができました。実際に使ってみましょう
> f = eval $ \x -> 3 * x * x + x * 1
> :type f
f :: Functor f => (Double -> f Double) -> Double -> f Double
> f' = grad f
> f' 1
7.0
> f' 2
13.0
> f' 3
19.0
f
として
f = 3x^2 + x + 1
という関数を想定しています。f
の実装にはHaskell標準の演算子が使われており、ポイントフリースタイルの制約は取り払われています。$f$の導関数は
\frac{df}{dx} = 6x + 1
となり
\begin{matrix}
\frac{df}{dx}(1) &=& 7 \\
\frac{df}{dx}(2) &=& 13 \\
\frac{df}{dx}(3) &=& 19
\end{matrix}
となるのでうまく実装できてそうですね
以上のように、演算子のオーバーロード・ソースコード変換・双方向の計算による実装を全て駆使して使い勝手の良いリバースモードの自動微分(1変数ですが)を実装することができました。このあとの課題として、型族を使って多変数関数に拡張したり、Expression Problemを解決しながらより多くの数値的な型クラスのインスタンスを作ったり、行列やベクトルに対応することなどが考えられますが、体力が尽きたので今回の実験はここまでにしておきます。
この記事を通して少しでも皆さんの自動微分への理解に貢献できれば幸いです。
追記: 2024/12/02
EDSLを使ったリバースモード自動微分について、Haskellのライブラリ ad, backprop で実際に使用されている より効率的な実装方法をだめぽ氏が解説してくれた記事があるのでそちらも参考にしてください!
-
Haskellの有名な自動微分ライブラリであるadやbackpropは演算子のオーバーロードを使ってリバースモードの自動微分を実装しています ↩
-
[1804.00746] The simple essence of automatic differentiation ↩
-
Reverse Mode Differentiation is Kind of Like a Lens II - Hey There Buddo! ↩ ↩2