計算グラフにおける 分岐ノードの誤差逆伝播は上流からの勾配の和になる ということが、自明のような感じで語られることが多いです。私にとっては、これがここ数年来の疑問でした。今回はこの疑問に対する考察をまとめたいと思います。
分岐ノードとは、流れが2つに分岐するノードで、入力 X がそのまま2つの分岐に流れていきます。分岐ノードを含む機械学習ネットワーク以下のように表すとします。計算グラフは最終目的地が損失計算となっており、分岐した流れもどこかで合流するはずです。このことを考慮して以下のような計算グラフを考えてみます。
入力データ X が Aノードで分岐する。
B ノードで X は Y1 に変えられ、 C ノードで X は Y2 に変えられる。
Y1 と Y2 は D ノードで合流し Z として流れていく。
X X Y1 Z
B
-----------------------|-----------------------
-----------| A(分岐ーノード) D |---------------------
-----------------------|-----------------------
C
X Y2
この計算グラフを以下のようなノード関数で表すとします。
Y1 = f_b(X)
Y2 = f_c(X)
Z = f_d(Y1, Y2)
さて、上の計算グラフの見方を変えて、分岐は無いものとして考えます。つまり B ノードと C ノードという2つのノードを、BC ノードという1つのノードと考えます。そして BC ノードからは、Y1+Y2 という2つ分の出力データが流れ出ていると考えます。
BC ノードには、X が入り、 Y1+Y2 が出ていきます。
Y1+Y2 は本来2つのデータである Y1 と Y2 をマージしたようなデータです。
X X Y1+Y2 Z
A BC D
-----------|-----------------------|-----------------------|---------------------
この計算グラフは以下のようなノード関数で表すことができます。
Y1+Y2 = f_bc(x)
Z = f_d(Y1,Y2)
合成関数の偏微分の公式より、Z は 変数Y1,Y2の関数であり、Y1,Y2は変数Xの関数と考えられるから、以下のように表されます。これはもともとの計算グラフの A ノードにおいて、分岐ノードの誤差逆伝播は上流からの勾配の和になる、ことを示している式となります。
\begin{align}
&\frac{\partial Z}{\partial X} = \frac{\partial Z}{\partial Y1} \frac{\partial Y1}{\partial X} + \frac{\partial Z}{\partial Y2} \frac{\partial Y2}{\partial X}
\end{align}
今回は以上です。