こんにちは。Qiitaで記事初投稿となります!
最近扱っているモデルであるRNN、LSTM、GRUについて理解を深めたい..!ということで、自分の勉強がてらに深層学習の仕組みから各モデルの特徴までまとめていきたいと思います!
アウトライン
- 深層学習とは ※既知の方はどうぞ飛ばしてください!
- RNN
- LSTM
- GRU
1. 深層学習とは
まず、モデルの中身に入る前に、大前提となる深層学習の仕組みについて簡単におさらいします。
深層学習は機械学習の一種です。
機械学習は、機械が大量のデータを学習し、判断・処理の精度を上げることを目指す手法です。一般的に、データを最もよく説明できる関数(数式)を探索するよう指示して、その探索を機械に任せるイメージです。
対して、深層学習は多層構造のネットワークを通じて、機械が重要な特徴・指標を自ら判断したうえで学習を行い、判断・処理の精度を上げることを目指す手法です。これにより、複雑なデータのパターンをより効果的に捉えることが可能になります。
例えば、従来の機械学習では、画像認識で「犬」を判別する場合、「耳の形状」「鼻のサイズ」といった特徴を手作業で定義する必要がありましたが、深層学習ではそれを自動的に見つけ出し、より高い精度で分類を行ってくれます。
ここでいう「深層」とは、関数を何重にも積み重ねて「層(layer)」で表現をすることに由来しています。つまり、深層学習で関数を積み重ねることによって複雑な関数を表現できるようになるということです。
例えば、深層学習の層構造を$f^{(n)}(...f^{(2)}(f^{(1)}(x))...)$のように表現されるとします(例:$f_2(x)=f(f(x))$)。
この関数$f^{(i)}$について、最もシンプルな関数である1次関数を重ねたところで、結局1次関数に変わりありません。深層学習においては、1次関数以外の関数が必要になります。
ちなみに、1次関数はグラフが真っ直ぐになるので線形であると言い、それ以外の関数を非線形であると言います。ここから、層を表す関数は非線形であることが必要になるということが分かると思います。
このような非線形関数を活性化関数$\varphi$と呼び、深層学習では$f^{(i)}(z)=\varphi(a_1z_1 +a_2z_2 +...+ a_mz_m + b)$の形で表されます。活性化関数にはいろいろな種類がありますが、ここではその説明は省略して、深層学習の数学にもちょっとだけ触れたいと思います。
深層学習の中身
この図は、3層構造の多層ニューラルネットワークを示したものです(一般的に重みがリンクしていない入力層はカウントされません)。
図のように、入力層 → 中間層(隠れ層) → 出力層と順番に情報を伝達していき、入力されたものが何であるのかを考え、答えを出力することが基本的な学習の流れです。この入力が与えられたときに各層を順番に計算していき、出力までの計算を行うことを順伝播といいます。
では、前の層からその次の層へ情報を伝達する際、図中の「線」と「丸」はどのような動きをしているのでしょうか。
なお、この線と丸にはいろいろな呼び方がありますが、ここではそれぞれノードとシナプスと呼ぶこととします。
線形変換
この図には、2つの層があり、ノードとノードを接続するシナプス(線)が書かれています。シナプスの上には、両端のノード間の結合重みを表しています。
入力層のノードが持つ値は、結合重みと掛け合わされ、出力層のノードに伝わります。出力層の 1 つのノードには、複数のノードから計算結果が伝わってくるので、これらを全部足し合わせることになります。
具体的には以下のような計算をしていることになります。
- $u_{11} = w_{11}h_{01} + w_{12}h_{02} + w_{13}h_{03} + b_1$
- $u_{12} = w_{21}h_{01} + w_{22}h_{02} + w_{23}h_{03} + b_2$
これが、線形変換、つまり重みを掛ける操作のことを指します。
さて、何気なく重みという言葉を使っていましたが、重みとは、各ノードが持っている値で、「どれだけこの情報が重要であるか」を表します。
例えば、白ワインか赤ワインかを見分けなければいけない状況において、「年数」「アルコール度数」「色合い」という3つの情報が直前のノードから得られたとします。「色合い」という情報があればすぐに種類を判断できますが、それ以外の情報では二者を見分けることはできません(大体のワインは色で赤/白を見分けられるはず..)。この場合「色合い」という情報の重みを大きくすることで、より精度高く赤か白かを見分けることが可能になります。
非線形変換
隠れ層では、一つ前の層に線形変換を適用した結果を受け取り、そこへさらに非線形変換を適用したものを出力します。
この図では、各層において線形変換のあとに非線形変換を施していることが分かります。
ここで、非線形変換をする際に使われる活性化関数とは一体何なのでしょうか。
それは、「その情報を次に伝達すべきか否かを判断するフィルター」のようなものをイメージしてください。
例えば、一般的な活性化関数の一つであるReLU(Rectified Linear Unit)関数は、入力が負の場合には出力は0、正の場合には入力をそのまま出力する関数です。
ニューラルネットワークは人間の神経細胞を模して設計されており、活性化関数は入力信号を肯定するもの、否定するものと考えられます。つまり、データを入力して活性化関数に通すと、それが白ワインであるとされる閾値を超えた途端、入力信号が肯定されます。このことをニューラルネットワークの文脈ではよく「発火」するといいます(正直呼び方に慣れないですが、閾値を超えた場合に信号が伝達される感じです)。
このように、各層において、線形変換に続いて非線形変換を施し、層を積み重ねて作られるニューラルネットワーク全体としても非線形性を持つことができるようにしています。
目的関数の最適化
この後、「実際に答えと照らし合わせて、正解であったかどうか」「不正解の場合、どこを修正すれば正解に近づけるのか(どのノードの重みをどう変更するべきか)」を学習していきます。
ここで、損失関数(目的関数)とは、「AIの予測と正解がどれくらい違っていたか」を求めるための関数です。損失関数の出力結果は損失と呼ばれ、この損失を最小化(最適化)するパラメータの調整方法に「勾配降下法」が用いられます。
図にも示しているように、分類問題の場合、損失関数として交差エントロピーが、回帰問題の場合、平均二乗誤差がよく使われます。ここでは、計算の分かりやすさの面から、回帰問題を考えていきます。平均二乗誤差とは、個々の実測値と予測値の差の二乗を平均した値であり、これをできるだけ小さくすることを目指します。
この図においては、パラメータ$w$を変化させた際の目的関数$L$の変化を二次関数で表しています。
初期値4に対する接線の傾き(勾配)が5で正の場合、負の方向に$w$を変化させていくと、最小値に近づいていくことが分かります。どのように近づけていけるかというと、現在の$w$から傾きを引いていくと逆方向への動きが実現できます。反対に、傾きが負の時は、更新量を足せば正の方向に変化できますね。
この傾きが正と負の2パターンの動きは次の式で表せます。
$$更新後のw = 更新前のw - wの更新量$$
この時、更新する幅は学習率というもので調整していきます。
つまり、学習率と勾配の積を更新量としてパラメータを変化させることで、目的関数$L$を最小にする$w$へと徐々に近づけることができます。
これを繰り返していき、重みを調整することで、この損失をいかに小さくするかが学習の根幹となります。
ニューラルネットワークの構造に先ほどのパラメータの更新を落とし込むとこのようになります。
例えば、$w_2$についての$L$の勾配は、$\frac{\partial L}{\partial w_2}$であり、これは合成関数の偏微分なので連鎖律を用いて$\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_2}$のように展開できます。2つの偏微分の積ですね。
同様に、$w_1$に関しては、$\frac{\partial L}{\partial w_1}$であり、連鎖律を用いて$\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial u} \cdot \frac{\partial u}{\partial w_1}$と表されます。
この計算は、層ごとに独立しており、前の層や中間結果が再計算されることになります。
つまり、損失関数$L$が$w_1$に影響を与える経路をすべてたどる必要があるということです。
- 損失関数$L$は最終出力$y$に依存している → $\frac{\partial L}{\partial y}$
- $y$は隠れ層の出力$h$に依存している → $\frac{\partial y}{\partial h}$
- $h$は隠れ層の入力$u$に依存している → $\frac{\partial h}{\partial u}$
- $u$は$w_1$に依存している → $\frac{\partial u}{\partial W_1}$
このように、各勾配を計算するたびにすべての依存関係を一から追跡し、同じ中間結果(例 h,u,y)を再計算する必要があります。
今まで見てきた方法は、損失関数$L$に対して、パラメータ$w_1$、$w_2$、$b_1$、$b_2$を、直接微分する方法でした。
これは、層ごとに個別に微分しており、膨大な計算コストがかかってしまいます。
このようなやり方では、多層ニューラルネットワークでは現実的ではないため、損失関数が出力した結果を利用して、出力層側から入力層側へと逆方向に伝達していく方法がよく知られています。
これを、逆伝播(バックプロパゲーション)と呼びます。どのような計算がされているかを見ていきましょう。
逆伝播(バックプロパゲーション)
この図は、先ほどまで見ていた3層のニューラルネットワークを別の書き方で表したものです。新しい入力$x$が与えられたときに、線形変換、非線形変換を施されていき目的関数の値$l$を計算している順伝播の様子が分かると思います。図中の丸いノードは変数を表し、四角いノードは関数を表しています。
次にやりたいことは、パラメータの更新ですね。各パラメータ$w_1$、$w_2$、$b_1$、$b_2$に記載の数式はパラメータの更新式です。学習率は決まった値だとして、目的関数の偏微分の値(例えば、$\frac{\partial L}{\partial w_1}$)が求まれば計算できます。
ここでは、$w_1$と$w_2$の更新量を考えてみましょう。
最初に$w_2$の目的関数に対する偏微分の計算です。
次に、$w_1$の目的関数に対する偏微分の計算です。
ここで注目することは、必要な目的関数の勾配は、各パラメータ($w_1$または$w_2$)のノードより先の部分(出力側)にある関数の勾配をかけ合わせれば計算できるということです。
例えば、$w_2$の偏微分値$\frac{\partial L}{\partial w_2}$は$\frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_2}$ですが、これは$w_2$のノードより先の2つの関数の勾配です。
つまり、順伝播とは逆向きに、各関数における入力についての勾配を求めて、掛け合わせていけば、パラメータについての目的関数の勾配が計算できます。このアルゴリズムを誤差逆伝播法と呼びます。
これで深層学習の仕組みについては終了です!
2. RNN
お待たせしました。ここからRNN(再帰型ニューラルネットワーク)の説明をしていきます。
今まで見てきた通常のニューラルネットワークでは、ある層の出力は次の層の入力に利用されるのみでした。
対して、RNNは、過去の情報を利用して現在および将来の入力に対するネットワークの性能を向上させる構造を持っています。
仕組み
RNNの隠れ層において、再帰的に出現する同じのネットワーク構造のことをセル(cell)と呼びます。この図は1つの時刻における計算フローを示しています。
流れとしては以下の3ステップです。
-
隠れ層の更新
$$h_t = f(Ux_t + Wh_{t-1})$$
前の時刻$t-1$での隠れ層の状態$h_{t-1}$と現在の時刻$t$の入力データを$x_t$を組み合わせ、活性化関数$f$を適用することで、次の隠れ層の状態$h_t$を計算します。
なお、時刻$t$の入力$x_t$には重み行列$U$をかけ、入力データを隠れ層用の情報に変換し、前の時刻の隠れ層の状態$h_{t-1}$にも重み行列$W$をかけ、過去の情報を反映しています。
-
出力層の線形変換
$$o_t = Vh_t$$
隠れ層の状態$h_t$を重み行列$V$を用いて線形変換し、出力層の中間値$o_t$を計算します。
隠れ層の情報を出力用の次元に変換する役割を果たしておりますが、この時点ではまだ最終的な出力形式(確率や実数値)にはなっていません。
-
活性化関数の適用
$$\hat{y}_t = g(o_t)$$
出力層の中間値$o_t$に活性化関数$g$を適用し、予測値$\hat{y}_t$を得ます。
活性化関数$g$はタスクに併せて出力の形式を変えることが役割です。例えば、分類タスクの場合はソフトマックス関数を、回帰タスクの場合は恒等関数(何も変換しない関数)を適用してそのまま実数値を出力します。
このように隠れ層の状態$h_t$を更新しながら、入力データを逐次処理しています。
次のこの図では、RNNが時系列データ$x_1, x_2,...,x_t$をどのように処理するかを示しています。
ここから、時刻$t$ごとの隠れ層状態$h_t$が入力$x_t$と過去の状態$h_{t-1}$に依存していること、時刻ごとの出力$\hat{y}_t$はどれぞれの$h_t$を基に計算されていることが分かります。
RNNは誤差の逆伝播計算を行うとき、層をさかのぼるにしたがって誤差が急速に小さくなり学習が進行できないという勾配消失問題を抱えています。これは、活性化関数を微分することによって得られる緩やかな傾きや同じ重みを何度も掛け算することによって引き起こされます。
また、誤差が大きくなりすぎてしまい学習が不安定になる勾配爆発という問題も存在します。
3. LSTM
LSTM(Long Short-Term Memory)は、RNNの一種であり、長期的な依存関係を学習できるモデルです。過去の情報を長期間覚えておくことが得意なモデルなので、例えば、長い文章で冒頭話題になった内容を後半で再び使う場合に、「話題を覚えておく力」を活用して、文章の意味をより正確に理解できます。
RNNが抱える「長期記憶の消失問題」を改善するために、記憶セルとゲート機構が導入されています。
この図はLSTMの全体構造を示したものです。LSTMはRNNを拡張したモデルでありますが、いくつかの重要な違いがあります。
-
ゲート機構
LSTMでは、再帰予測を繰り返す中で、長期間にわたる不要な記憶を少しずつ消去し、必要な情報を保持する仕組みを「ゲート機構」によって実現しています。ゲートが分かりづらい場合、伝達される情報の量を調整する出入り口をイメージしてください。
ゲート機構には以下の3種類があります
- 忘却ゲート:長期記憶$c_{t-1}$のどの情報を忘れるかを調整
- 入力ゲート:入力$x_t$と過去の状態$h_{t-1}$を受け取り、新しい情報をどれだけ記憶するかを決定
-
出力ゲート:更新された長期記憶$c_t$から、次の隠れ状態$h_t$にどの情報を出力するかを調整
-
長期記憶$c_t$
RNNでは短期的な記憶として隠れ状態$h_t$がありましたが、LSTMではこれに加えて、長期間の情報を保持するためのセル状態$c_t$を導入しています。
-
活性化関数の役割
RNNでは、過去の情報と現在の情報を非線形に組み合わせるために活性化関数tanhが使われていました。LSTMでは、$σ$と$tanh$を併用されています。
-
シグモイド関数($σ$)
忘却ゲート、入力ゲート、出力ゲートで、情報を保持・消去・出力する割合を調整。座標点(0, 0.5)を基点として点対称となるS字型の滑らかな曲線で、0~1の間の値を取る。 -
双曲線正接関数($tanh$)
新しい情報を生成したり、セル状態をスケーリングする際に使用。座標点(0, 0)を基点として点対称となるS字型の滑らかな曲線で、-1~1の間の値を取る。
では、その仕組みを順を追って見ていきましょう。
仕組み
1.忘却ゲートの計算
忘却ゲートの出力$f_t$が計算される様子を示しています。
$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$
忘却ゲートはこれまでの「記憶」(セル状態)$c_{t-1}$からどの情報を忘れるかを決める忘却率$f_t$を出力します。
例えば、記憶セルのデータを全て消去する必要がある場合は忘却ゲートから出力される数値は0、全て残しておく場合は1が出力されます。このデータの忘れる度合いに応じて、0から1までの出力がされます。「記憶の取捨選択」ですね。
2.入力ゲートの計算
2枚の図を使って新しい情報を記憶するシグモイド層(入力ゲート)の計算と、候補となる新しい記憶を生成するtanh層のプロセスを示します。
①どれだけ新しい情報を追加するかを決める入力率$i_t$を出力
$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$
②入力ゲートの値の分だけ、記憶セルに保存するための新しい情報として$\tilde{c_t}$を作成
$$\tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)$$
3.セル状態の更新
この図では、忘却ゲートと入力ゲートの出力を使ってセル状態$c_t$を更新する様子を示しています。
長期記憶をどれぐらい忘れるかと、短期記憶を新たにどれぐらい覚えるかを更新します。
$$c_t = f_t \circ c_{t-1} + i_t \circ \tilde{c_t}$$
- $c_t$: 更新されたセル状態(出力ゲートへ渡される)
- $f_t \circ c_{t-1}$:忘却ゲート出力で調整。前ステップまでため込んでおいた長期記憶をどのぐらいステップ$t$で保持し、残りを廃棄するかを調整する忘れる記憶
- $i_t \circ \tilde{c_t}$: 入力ゲートで調整した入力値。長期記憶と短期記憶をどのぐらいセルに保持するかという覚える記憶
- $\circ$:要素ごとの積
4.出力ゲートの計算
2枚の図を使って、出力ゲート$o_t$の計算と隠れ状態$h_t$の更新についてのプロセスを示します。
$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$
②出力ゲートの出力値$o_t$を用いて、次の時間ステップや隠れ層に渡される短期的な記憶である$h_t$を予測
$$h_t = o_t \circ \tanh(c_t)$$
今まで見てきた流れをざっくり表すと以下のとおりです。
- 忘却ゲートで不要な記憶を捨てる
- 入力ゲートで新しい情報を追加する
- セル状態を更新して「記憶」を保持する
- 出力ゲートで次の時刻へ渡す隠れ状態を決定する
LSTMは、RNNが抱える課題を解決するための重要なステップであり、特に長期的な依存関係を扱う多くのタスクで大きな成果を上げています。その一方で、計算量や学習速度の課題も存在しています。
次に、LSTMを簡素化したモデルで、より高速に学習できるGRU(Gated Recurrent Unit)を見ていきましょう。
4. GRU
この図はGRU(Gated Recurrent Unit)の全体構造を示したものです。
GRUはLSTMの簡易版とされ、計算効率化・軽量化を狙った構造となっています。LSTMと同等以上の精度を保ちながら、以下のような改良により効率を向上させています。
-
LSTMの「入力ゲート」と「忘却ゲート」を統合して「更新ゲート」とする
- 更新ゲート$z_t$:過去の情報をどれだけ残し、新しい情報にどれだけ置き換えるかを制御
-
リセットゲートを導入し、不要な過去情報を除去する仕組みを簡素化
- リセットゲート$r_t$:過去の情報をどれだけ無視するかを制御
- 長期記憶セル$c_t$を廃止し、隠れ状態$h_t$だけに記憶を集約
では、その仕組みを見ていきましょう。
仕組み
1.リセットゲートの計算
リセットゲートは、前時刻の情報をどれだけ忘れるかを決定します。
この値$r_t$は過去の状態をリセットする際に使用されます。
$$r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$
なお、LSTMの忘却ゲートと同じく「前フレームの潜在状態$h^{(t-1)}$をどれだけ忘れるか(どれくらいメモリ上から除去するか)」の役割を担当しています。
2. 更新ゲートの計算
更新ゲートは、入力データ$x_t$と前時刻の隠れ状態$h_{t-1}$を用いて、新たに記憶する情報の割合$z^{(t)}$を計算します。
$$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)$$
3.候補隠れ状態の計算
リセットゲート$r_t$の影響を受けた前時刻の状態を用いて、新しい「候補隠れ状態」$\tilde{h}_t$を計算します。
これにより、過去の情報の選別が行われます。
$$\tilde{h}_t = \tanh(W_h x_t + r_t \circ (U_h h_{t-1}) + b_h)$$
4.隠れ状態の更新
更新ゲート$z_t$を用いて、過去の状態$h_{t-1}$と新しい候補隠れ状態$\tilde{h}_t$を加重平均し、次の隠れ状態$h_t$を決定します。
この加重平均により、新しい情報と過去の情報のバランスが調整されます。
$$h_t = z_t \circ h_{t-1} + (1 - z_t) \circ \tilde{h}_t$$
今まで見てきた流れをざっくり表すと以下のとおりです。
- 更新ゲートで「過去の情報をどれだけ保持するか」を決める
- リセットゲートで「過去の情報をどれだけ無視するか」を決める
- 候補状態を生成し、更新ゲートを使って新しい隠れ状態を決定する
GRUは、RNNのシンプルさを保ちながら、長期依存関係を学習できる優れたモデルです。更新ゲートとリセットゲートによる効率的な構造により、LSTMに匹敵する性能を持ちながらも、計算負荷が軽く、実用的な場面で頻繁に使用されます。
ただし、学習対象のデータによっては、どちらのモデルも精度に差が出るので、状況によって適切な方を選ぶことが重要です。
以上の内容をまとめると以下のとおりです。
特徴 | LSTM | GRU |
---|---|---|
構造の複雑さ | 複雑(3つのゲート構造:入力、出力、忘却) | シンプル(2つのゲート構造:更新、リセット) |
パラメータ数 | 多い | 少ない |
計算効率 | 比較的低い | 高い |
モデルの性能 | 長期依存関係の学習に適している | 多くのタスクでLSTMに匹敵 |
実用性 | 計算リソースが十分な場合に選択される | リソース制約下での選択肢 |
実用例 | 自然言語処理(NLP):文章生成、機械翻訳 音声認識:長時間の音声データの処理 |
時系列データ:気象データ予測、株価予測 モバイルデバイス:リソース制約下でのアプリケーション(例:チャットボット) |
以上です。読んでいただきありがとうございました。
参考記事
以下のサイトを参考にさせていただきました。