【機械学習】誤差逆伝播法のコンパクトな説明
【機械学習 誤差逆伝播法】word2vecメモ (1)
【機械学習 誤差逆伝播法】word2vecメモ (2)
※公式の番号は、上記の記事間で共通です。
本記事は「ゼロから作るDeep Learning2」(以下「ゼロから本2」)のword2vec(4章)の読書メモです。4章ではword2vecのCBOWモデルをEmbeddingノードとEmbeddingDotノードを使ってPythonプログラムとして実装してあります。Pythonプログラムは具体的なものですが、数学的にどのような計算が行われているかが見えにくい部分もあります。ここでは計算メモとして、その辺をもう少し明らかにしたいと思います。
ここでの議論は、3章で説明してあるword2vecの実装の高速化です。必要に応じて「【機械学習 誤差逆伝播法】word2vec (1)」を参照してください。
1章のMatMulノードについては、「【機械学習】誤差逆伝播法のコンパクトな説明」 にまとめてあります。そこで示した公式を、必要に応じて公式番号で参照します。ご参照ください。
#Embeddingノードの勾配
CBOWモデルにおいては、コンテキストを入力します。これは2つのone-hotベクトルです。Forward計算(推論)において、one-hotベクトルと行列$W_{in}$との掛け算は、行の抜き出しに他なりません。わざわざidxをone-hotベクトルに変換して、行列計算を行い$W_{in}$の行を抜き出すのは計算の無駄なので、効率化しようというのがEmbeddingノードです。
\begin{align}
\\
&CBOWモデルの入力idxを次のようにします。(バッチ処理を考慮)\\
&idx=[1,0,3,0]\\
\\
&W_{in}を以下のようにします。\\
\\
&W=W_{in}=
\begin{pmatrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33} \\
w_{41} & w_{42} & w_{43} \\
w_{51} & w_{52} & w_{53} \\
w_{61} & w_{62} & w_{63} \\
w_{71} & w_{72} & w_{73} \\
\end{pmatrix}
\\
\\
\\
&Forwardの計算です。\\
&行列計算は行わず、numpy配列からidx要素を直接抜き出します。\\
&(行ベクトルの抜き出し)\\
\\
&Y = [\\
&[w_{21},w_{22},w_{23}],\\
&[w_{11},w_{12},w_{13}],\\
&[w_{41},w_{42},w_{43}],\\
&[w_{11},w_{12},w_{13}],\\
&]\\
\\
\\
\\
&Backwardの計算です。\\
&Embeddingノードの計算は数学的にはMatMulノードと同じですから、\\
&その偏微分もMatMulノードと同じ計算で行います。\\
\\
&\frac{\partial L}{\partial W} = X^T \times \frac{\partial L}{\partial Y} \tag{3-2}\\
\\
&ここで例えばidx=[1,0,3,0]に対するone-hotベクトルXの転置行列は\\
&以下のようになります。\\
\\
&X^T=
\begin{pmatrix}
0 & 1 & 0 & 1 \\
1 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 \\
0 & 0 & 1 & 0 \\
0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 \\
\end{pmatrix}
\\
\\
\\
&前の層から伝達される勾配を以下のように表します。\\
&\frac{\partial L}{\partial Y} =
\begin{pmatrix}
l_{11} & l_{12} & l_{13} \\
l_{21} & l_{22} & l_{23} \\
l_{31} & l_{32} & l_{33} \\
l_{41} & l_{42} & l_{43} \\
\end{pmatrix}
\\
\\
\\
&\frac{\partial L}{\partial W_{in}} =
\begin{pmatrix}
l_{21}+l_{41} & l_{22}+l_{42} & l_{23}+l_{43} \\
l_{11} & l_{12} & l_{13} \\
0 & 0 & 0 \\
l_{31} & l_{32} & l_{33} \\
0 & 0 & 0 \\
0 & 0 & 0 \\
0 & 0 & 0 \\
\end{pmatrix}
\qquad \qquad \qquad ∵公式(3-2)
\\
\\
\\
&「ゼロから本2」でidxで0が2重にエントリーされている時の処理は、\\
&上書きでなく足し算であることが注意されていますが、確かにそのようになります。
\\
&\qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \\
\end{align}
#EmbeddingDotノードの勾配
EmbeddingDotノードは、$W_{out}$の行列から行ベクトルの抜き出しで得られた多値のベクトルを二値のスカラーに変えるために、行ベクトルとターゲットの単語(say)の内積を求めます。ターゲットの教師データ(Yes/No)も与えられ、Lost関数が計算され、その勾配が計算されることで、パラーメータ$W_{out}$が適切な値に更新されていきます。
EmbeddingDotノードはEmbeddingノードを再利用して実装されるために、$W_{out}$の形状は$W_{in}$に合わせます。つまりここではidxに対応した列を抜き出していましたが、行を抜き出す操作に変えます。MatMulノードでの実装とこの辺が変わりますので注意が必要です。
\begin{align}
\\
&中間層のベクトルです。\\
&X=(x_1 x_2 x_3)\\
\\
&2値分類のターゲットとなるidx=1とします。\\
&(簡単のためにバッチ処理ではなくidx=1という単体で考えます。)
\\
\\
&W_{out}を以下のようにします。\\
&以前のMatMulノードの時とは違って、W_{out}の行と列を転置したものにします。\\
&EmbeddingノードのW_{in}と形状をあわせて、プログラムを再利用するためです。\\
&W=W_{out}=
\begin{pmatrix}
w'_{11} & w'_{12} & w'_{13} \\
w'_{21} & w'_{22} & w'_{23} \\
w'_{31} & w'_{32} & w'_{33} \\
w'_{41} & w'_{42} & w'_{43} \\
w'_{51} & w'_{52} & w'_{53} \\
w'_{61} & w'_{62} & w'_{63} \\
w'_{71} & w'_{72} & w'_{73} \\
\end{pmatrix}
\\
\\
\\
\\
&Forwardの計算です。\\
\\
&X W_{out} の全計算を行いません。 \\
&代わりに idx=1 の行を抜き出して W_t = (w'_{21} w'_{22} w'_{23}) とします。\\
&(W_t=targetWの略です。)\\
&Xとの内積を求めます。\\
&Y=W_t * X = \sum_{k=1}^3 x_k w'_{2k}\\
&この内積値が次のSigmoidWithLoss関数に渡されます。\\
\\
\\
\\
\\
&Backwardの計算です。\\
\\
&\frac{\partial L}{\partial X} =
\frac{\partial L}{\partial Y} \frac{\partial Y}{\partial X} =
\frac{\partial L}{\partial Y} \frac{ \sum_{k=1}^3 x_k w'_{2k}}{\partial X} =
\frac{\partial L}{\partial Y} (w'_{21} w'_{22} w'_{23}) =
\frac{\partial L}{\partial Y} W_t
\\
\\
&(Pythonプログラムでは\frac{\partial L}{\partial Y} = dout です。)
\\
\\
&\frac{\partial L}{\partial W_{out}} =
\frac{\partial L}{\partial Y} \frac{\partial Y}{\partial W_t} \frac{\partial W_t}{\partial W_{out}} =
\frac{\partial L}{\partial Y} (x_1 x_2 x_3) \frac{\partial W_t}{\partial W_{out}} \\
\\
\\
&∵\frac{\partial Y}{\partial W_t} =
\frac{\partial \sum_{k=1}^3 x_k w'_{2k}}{\partial W_t} =
(x_1 x_2 x_3) \\
\\
&\frac{\partial L}{\partial Y} (x_1 x_2 x_3) = (x'_1 x'_2 x'_3)とすると、\\
&W_t=f(W_{out}) は idx=1 の行を抜き出す操作だから、\\
&Embeddingノードで見たように以下のようになります。\\
\\
&\frac{\partial L}{\partial W_{out}} =
\begin{pmatrix}
0 & 0 & 0 \\
x'_1 & x'_2 & x'_3 \\
0 & 0 & 0 \\
0 & 0 & 0 \\
0 & 0 & 0 \\
0 & 0 & 0 \\
0 & 0 & 0 \\
\end{pmatrix}
\\
\\
&\qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \qquad \\
\end{align}
以上でword2vec高速化のコアのノードであるEmbeddingノードとEmbeddingDotノードの詳細な計算メモを終えます。word2vecの全体の動作はやはり「ゼロから本2」をご参照ください。