導入
行列 $A$,ベクトル $b$ から線形方程式の解$x$を得るという操作ありますよね.($Ax=b$)
まれに深層学習でも用いる瞬間が訪れます.私はPyTorchを使っているのですが,
x = torch.solve(b, A)
で全てがよしなに行われます.そこで疑問に思うのがPyTorchはどういう数式で線形方程式をBackPropagationしてるんでしょう?
線形層や畳み込み層と違い,線形方程式の誤差逆伝搬の式を見たことがないのでここに記そうと思います.
レベルとしてはベクトル微分程度の線形代数と微積を理解している方向けです.
諸定義
文字の定義
本稿で現れる数字と記号の定義は以下です.
$A \in \mathcal{R}^{n\times n}$
$b,x \in \mathcal{R}^{n}$
$e \in \mathcal{R}$
$\cdot^T$ : 行列またはベクトルの転置.
$\cdot^{-1}$ : 正則行列の逆行列.
ロスの定義
最小化するロス$e$は$x$から何らかの関数$L(\cdot)$を通して生成されます.
つまり,$e = L(x)$ です.
そして,線形方程式の誤差逆伝搬の式を導出したいので,ロスから解までの微分が既に計算されているとします.
つまり,given $\frac{\partial e}{\partial x}$です.
今回解く問題とは
誤差逆伝搬の式を導出するというのは,$\frac{\partial e}{\partial A},\frac{\partial e}{\partial b}$の計算式を導出することです.
カンニング
線形方程式の誤差逆伝搬の導出は意外と難しいので先に答えをカンニングしましょう.
PyTorchによると解答は以下です.
\begin{align}
\frac{\partial e}{\partial b} &= (A^{-1})^T \frac{\partial e}{\partial x}\\
\frac{\partial e}{\partial A} &= -\frac{\partial e}{\partial b}x^T
\end{align}
何やら行列の逆伝搬に対してベクトルの逆伝搬の結果を用いるようですね…?
誤差逆伝搬の式の導出
ベクトルbの逆伝搬の導出
まずはベクトル$b$の微分について求めてみましょう.$b$については素直に求めることができます.
\begin{align}
\frac{\partial e}{\partial b} &= \frac{\partial x}{\partial b} \frac{\partial e}{\partial x}\\
&= \frac{\partial (A^{-1}b)}{\partial b} \frac{\partial e}{\partial x}\\
&= (A^{-1})^T \frac{\partial e}{\partial x}
\end{align}
連鎖律を使った誤差逆伝搬法ですね.2→3行目は一般的なベクトル微分の公式を用いました.
これで問題ありませんが,次から行列の微分が出てくるため,ベクトル$b$の逆伝搬をテンソル表記で解くことを試してみます.
\begin{align}
\frac{\partial e}{\partial b_k} &= \frac{\partial x_l}{\partial b_k} \frac{\partial e}{\partial x_l}\\
&= \frac{\partial (A^{-1}b)_l}{\partial b_k} \frac{\partial e}{\partial x_l}\\
&= \frac{\partial ((A^{-1})_{l,m}b_m)}{\partial b_k} \frac{\partial e}{\partial x_l}\\
&= (A^{-1})_{l,m}\frac{\partial b_m}{\partial b_k} \frac{\partial e}{\partial x_l}\\
&= (A^{-1})_{l,m}\delta_{m,k} \frac{\partial e}{\partial x_l}\\
&= (A^{-1})_{l,k} \frac{\partial e}{\partial x_l}\\
&= (A^{-1})^T_{k,l} \frac{\partial e}{\partial x_l}
\end{align}
以上より結果として同じく$\frac{\partial e}{\partial b} = (A^{-1})^T \frac{\partial e}{\partial x}$が得られます.手間が増えてるように見えるかも知れませんが,より階数の高いテンソルに対して考える場合は非常に有効な手段となります.実際行列$A$の逆伝搬をテンソル表記を用いずに行うのは至難の技だと思います.
行列Aの逆伝搬の導出
ではテンソル表記を用いて行列$A$の逆伝搬を考えましょう.
ただ,今回は可読性を上げるため敢えて$W=A^{-1}$と$A$の逆行列に名前を付けておきます.
そうすることで$x=Wb$となるので,$x\rightarrow W\rightarrow A$と二段階に分けて考えることができます.
行列Wに対する逆伝搬
\begin{align}
\frac{\partial e}{\partial W_{i,j}} &= \frac{\partial e}{\partial x_k} \frac{\partial x_k}{\partial W_{i,j}}\\
&= \frac{\partial e}{\partial x_k} \frac{\partial (Wb)_k}{\partial W_{i,j}}\\
&= \frac{\partial e}{\partial x_k} \frac{\partial (W_{k,l}b_{l})}{\partial W_{i,j}}\\
&= \frac{\partial e}{\partial x_k} \frac{\partial W_{k,l}}{\partial W_{i,j}}b_{l}\\
&= \frac{\partial e}{\partial x_k} \delta_{i,k} \delta_{j,l} b_{l}\\
&= \frac{\partial e}{\partial x_i} b_{j}\\
\end{align}
となります.$\frac{\partial e}{\partial W} = \frac{\partial e}{\partial x} b^T$ということですね.
逆行列の微分
$x \rightarrow W$の逆伝搬は明らかになったので,$W \rightarrow A$の伝搬を考えます.
つまり$\frac{\partial W}{\partial A}$を求めます.
W_{k,l}A_{l,m} = I_{k,m}
まず上記は逆行列の定義より明らかです.ここで両辺を$A$で微分すると,
\begin{align}
\frac{\partial}{\partial A_{i,j}}(W_{k,l}A_{l,m}) &= 0_{i,j,k,m}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}}A_{l,m} + W_{k,l}\frac{\partial A_{l,m}}{\partial A_{i,j}} &= 0_{i,j,k,m}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}}A_{l,m} &= - W_{k,l}\frac{\partial A_{l,m}}{\partial A_{i,j}}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}}A_{l,m} &= - W_{k,l}\delta_{i,l}\delta_{j,m}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}}A_{l,m} &= - W_{k,i}\delta_{j,m}
\end{align}
整理され,両辺が$i,j,k,m$の4階のテンソルの等式となっています.
ここでさらに両辺に$W_{m,l}$を右から加えることで,$i,j,k,l$のテンソルの式に変えます.
\begin{align}
\frac{\partial W_{k,l}}{\partial A_{i,j}}A_{l,m}W_{m,l} &= - W_{k,i}\delta_{j,m}W_{m,l}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}}I_{l,l} &= - W_{k,i}\delta_{j,m}W_{m,l}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}} &= - W_{k,i}W_{j,l}\\
\end{align}
以上で$W,A$間の関係を導けました.なお得られた結果はベクトルや行列の表記で表すことは困難です.
しかし解釈は用意で$i,j,k,l=1,2,3,4$とした例を挙げると,$\frac{\partial W_{3,4}}{\partial A_{1,2}} = - W_{3,1}W_{2,4}$となります.つまり$A$の逆行列($W$)の3行目4列目の成分を$A$の1行2列目の成分で偏微分した結果は逆行列の3行目1列目の要素と2行目4列目の要素の積のマイナスという簡単な表現で表せるということです.
Aの逆伝搬の解
\frac{\partial e}{\partial W_{k,l}}= \frac{\partial e}{\partial x_k} b_{l}\\
\frac{\partial W_{k,l}}{\partial A_{i,j}} = - W_{k,i}W_{j,l}
既に得られた以上の2式から$\frac{\partial e}{\partial A_{i,j}}$を得ましょう.
\begin{align}
\frac{\partial e}{\partial A_{i,j}} &= \frac{\partial W_{k,l}}{\partial A_{i,j}}\frac{\partial e}{\partial W_{k,l}}\\
&= - W_{k,i}W_{j,l}\frac{\partial e}{\partial x_k} b_{l}
\end{align}
はい.とても単純で拍子抜けですね.テンソル表記なら完全に終わりですが,このままだとベクトル・行列はできないのでちょっと工夫してみましょう.
まず,$W_{j,l}b_{l}$の部分のみは$Wb$と表記できます.これはつまり$x$なので以下に書き換えることができます.
\begin{align}
\frac{\partial e}{\partial A_{i,j}}
= - W_{k,i}\frac{\partial e}{\partial x_k} x_{j}
\end{align}
また,$W=A^{-1}$であることを考えるとベクトル$b$の逆伝搬の式が適用できます.
\begin{align}
\frac{\partial e}{\partial A_{i,j}}
= - \frac{\partial e}{\partial b_i} x_{j}
\end{align}
これのベクトル・行列表記は$\frac{\partial e}{\partial A} = - \frac{\partial e}{\partial b} x^T$となります.全ての導出が終わりました.お疲れ様です!
間違った導出
最初私もハマってしまったのですが,同様のミスを他の方もしてしまう可能性があるのでおまけとして記述しておきます.
行列$A$に対する逆伝搬を求める際に$e \rightarrow x \rightarrow A$ではなく,$e \rightarrow b \rightarrow A$として求めてしまうと大間違いします.
\begin{align}
\frac{\partial e}{\partial A} &= \frac{\partial b}{\partial A} \frac{\partial e}{\partial b}\\
&=\frac{\partial (Ax)}{\partial A} \frac{\partial e}{\partial b}\\
&=x \frac{\partial e}{\partial b}\\
\end{align}
間違いは2点あり,1つは前述の$e \rightarrow b \rightarrow A$という経路は逆伝搬として間違っています.
2つ目は$\frac{\partial (Ax)}{\partial A}=x$という式は合ってそうで合っていません.これは行列によるベクトルの微分はベクトル・行列表記をそもそもすることができません.
実際の答えと似ているためとても紛らわしいですね.
また$Ax=b$という関係から,$x$を小さくするなら$b$は小さく,$A$は大きくすべきですよね.正しい答えには$A$の偏微分に対してマイナスがついているが,こちらにはついていないことも間違いであることを支持しますね.