本題
これは、ResNetについて考察した4本立て(序、破、Q、:||)の一記事です。
というわけで、今回は前回の続きなのですが、前回は
- 研究者がどうやってこの残差ブロックのアイディアに至ったのか
- そもそもなぜ残差を学習するのか
- なぜResNetが(数学的に)高精度を達成できるのか
この3つの問題定義をしたあと、問題1について
ResNet始祖論文を用いて考察しました。
今回は、問題2について考察していくのと同時に、
の考察をもとに、ResNetの裏側に迫っていきます。
まず、ResNetの根本はショートカットを用いたモデル劣化の抑制にある。ということは理解していただけたと思いますが、
残差を学習するって、実際どういう意味があるんでしょう?
これを知るために、「残差を学習する機械学習モデルって、ほかにあったっけ?」
と考えると、
あー、「なんか勾配ブースティングとかあったよなあ」
となるわけです。ここで、てことは、ResNetってもしかして勾配ブースティングと関係あるのかな、とナイーブに考えられると思います。
ただ、勾配ブースティングてなんだったっけ?ってなる人がいる(まさに自分のこと)ので、
勾配ブースティングについて少し考察してみましょう。
勾配ブースティングでは、複数の弱学習器を直列につなげて、
個々の弱学習器が、前の弱学習器の予測結果を合わせたものと、実際の正解ラベルの「差異」を学習します。
これがどういう意味を持つかというと、それぞれの弱学習器が、前の学習器たちを合わせた予測と正解ラベルの間に計算できる「ロス」を小さくする方向に動かしていく。ということです。
「ロス」は微分可能な関数であり、一番単純にロスを減少させるときは、ロスの一回微分に負の値を書けたものを足すことで実現可能です。
これは普通にNNを学習するときに重みの更新の際にも使われるので、SGDという名前がついています。
残差学習は、簡単に言えばこれをロスについて行っています。最初(一番はじめの)弱学習器が生成したロスを、後に続く弱学習器達でロスの負の一回微分を計算すること(残差を学習することに同義)で、学習器全体でのロスを減らしているわけです。
これを踏まえた上で、改めてResNetと勾配ブースティングの関係について調べてみると、
Residual Networks Behave Like Ensembles of Relatively Shallow Networks この論文が非常に面白いことをしています。
彼らの言うUnraveled view (b) は、ResNetを、入力が取っている経路の長さ別に並列方向に書き直したものです。
例えば、ResNetに入った入力が、たくさんショートカットを通ったとします。それらが実際にされた演算回数(ここで言う経路の長さ)は、一回もショートカットを行わずに出力までたどり着いたものに比べて非常に小さくなります。
つまり、ショートカットを行ったものに関しては、演算を「されずに」出力まで到達しているのです。
ここで、残差ブロックn個で構成されているResNetを考えると、経路を選ぶ際に、ショートカットを通るか、通らないかの二択をn回行うことになるので、2^n個の経路オプションがあるわけです。
これらの異なる経路オプションをとるものたちが次々に加えられて構成されるアーキテクチャを見ると、たしかにアンサンブルっぽくない?となるわけです。
そこでVeitさんたちは次の実験をしてみました。
- いくつかの経路を切ったら(断線させたら)ResNet/VGGなどの他のモデルはどういう影響受けるの?
結果はすごく興味深いもので、ResNetはいくつかの経路を切られてもほとんど影響を受けなかったのに対し、VGGなどのモデルは影響を多く受けたのです。
さらにVeitさんは以下の実験も行いました。 - 残差ブロックを消すのではなく、複数個入れ替えたらResNetはどうなるの?
結果は右側のグラフですが、これも驚くべきことにResNetはブロックの入れ替えに対してもあまり影響を受けにくいことがわかったのです。
ブロックを消しても入れ替えても、ResNetは影響をあまり受けない、
これはアンサンブル学習の特徴そのものやん!!おもしろ!!ってなりますよね?
これは先程の説明で行くと、勾配ブースティングのアンサンブル学習では、後続の弱学習器がロス関数を減らす方向に動いていたので、それらの数が減らされたり、順番を入れ替えられたりしてもそんなに影響を受けにくいのです。
どうやら、ResNetはアンサンブル学習と深いつながりがあるみたいです。ただ、それぞれの経路の長さは、先程の説明のようにどの経路をとるかはバラバラになっています。ただ、どの経路長をとるものが多いかは計算できて、それをプロットしたものは(a)です。最長の経路が2^50だとすると、経路長の分布は2^25を平均とした左右対称な分布に成っていることがわかります。
ここでVeitさんはさらに次の実験をしました。
- 学習の際に、どの経路長のものが学習により貢献しているんだろう??
結果は(c)のようになりました、
つまり!!実際のResNetの学習に寄与している(偏微分による誤差伝搬による)ものは、比較的短い経路を持つものに偏っている!!!
ことがわかったのです!!
これすごくないですか? 残差ブロックを介すると、誤差伝搬が長くなる!!すげー!!て思っていた人も多いと思います(自分はそうでした)。
ただ、これは間違いで、実際に伝わっている誤差の経路は、ResNetの全長(一回もショートカットしなかったおりこうさん)よりも遥かに短いのです。
##破のまとめ
ResNetは、確実にアンサンブルっぽい挙動をしている。また、それらのアンサンブルは残差ブロックにより、異なる経路長を持つ弱学習器である。
残差学習は、勾配消失そのものを解決しているわけではなく、実際に学習に寄与しているもののほとんどが、短い経路長をもつものである。
というわけで、非常に(自分的には)衝撃的なことがわかった破でしたが、このようなオブザベーショナルな論文ってすごいありがたいな、とおもいました。