Sparse Networks from Scratch論文まとめ
- 2019年7月にsubmitされた論文で、再学習を必要としないSparseなネットワークの構築手法を提案・考察しています。
- arXivはこちら
- Sparseは「疎」を意味し、Denseと対になる概念です。
この記事で説明すること
- この論文で提案されているSparse Learningという学習手法
- 実験結果の概要
この記事で説明しないこと
- Momentumについての説明
- 実験で用いられたモデル構造の詳細
ザックリとまとめていきます。サクッと見ていきましょう。
Sparse Learningとは
この方法の一番の特徴は、論文名にもあるように**「Scratchから疎なネットワークを学習する」という点にあります。
従来の疎なネットワークの学習手法としては、まずはじめに密なネットワークを学習・構築してから、重要な部分だけを残すという手法がほとんどでした。
つまり、従来の手法では密なモデルの学習が必須**であり、学習にかかる時間がネックでした。
では、以下で具体的にはどのようにして学習を行うかを見ていきたいと思います。
Sparse Learningは大きく分けて以下の三つの手順に分解されます。
- モデル全体に対する各レイヤの貢献度を計算する
- 貢献度の低いweightsを取り除く(Prune)
- 2で取り除いた分を再分配する(Redistribution)
これらの計算・操作は各Epochの最後に行われます。
では各手順についてもう少し詳しく見ていきます。
1.モデル全体に対する各レイヤの貢献度を計算する
まずこの「貢献度」とは何か、という部分ですが、これは**「Error(損失)をWeigths(重み)で偏微分したもの」**ということになります。
この偏微分は**「その重みの変化によって生じる損失の変化」**なので、これの絶対値が大きいほどその重みの貢献度は大きい、という事ができます。(例 貢献度100の重みは、重みの1の変化で損失を100変化させる影響力があるが、貢献度1の重みは、重みの1の変化で損失を1しか変化させる事ができません。)
しかし、この偏微分をそのまま貢献度に使うことは危うい手法です。
なぜなら確率的勾配降下法(SGD)では各ミニバッチにおいて偏微分の値が大小・正負で振動してしまうからです。
そこで、ある時点の偏微分をそのまま使うのではなく、時間方向で平均をとったものを利用します。
更に、近い過去の結果を遠い過去の結果よりも優先するために、下記の式で重み付けを行います。
上式がまさにMomentumであり、論文内ではSparse Momentumと呼ばれています。
また、αはsmoothing factorと呼ばれ、過去の影響の大きさを決定する値(ハイパーパラメータ)になっています。
Sparse Learningの最初のステップは上記のSparse Momentumを各レイヤごとに求める、ということになります。
2. 貢献度の低いweightsを取り除く(Prune)
次のステップでは、貢献度の低い重みを各レイヤから取り除いていきます。
1で計算した各レイヤの貢献度を利用して、そのレイヤから幾つの重みを取り除くかを決定します。
更にそのレイヤの中で絶対値が小さい重みから取り除いていきます。(モデル全体における各レイヤの貢献度→各レイヤ内での重みの絶対値の順で見ます)
この時にモデル全体から幾つの重みを取り除くかを決める割合をpruning rateと呼び、論文内の図では50%を例にしています。このpruning rateは学習率のように次第に減少させていくことが良い結果をもたらすと述べられています。
3. 2で取り除いた分を再分配する(Redistribution)
最後に、2で取り除いた重みの中から、大事な重みを復活(Regrowth)させます。
具体的には、2と同じようにして復活させる重みの数を計算し、2で取り除かれた重みの中から絶対値の大きいもの順で復活させます。
Sparse Learning操作方法まとめ
ここで一度まとめておきます。
- 各レイヤのモデル全体に対する貢献度(Sparse Momentum)を計算する
- 貢献度に基づいて各レイヤから取り除く重みの数を決定し、絶対値の小さい重みから順に取り除く
- 貢献度に基づいて各レイヤで復活させる重みの数を決定し、絶対値の大きい重みから順に復活させる
2、3の手法は同じような計算をしている事がわかると思います。
実験結果概要
この論文内ではMNIST,CIFAR-10,ImageNetでの精度、CIFAR-10での実行速度についての結果が示されています。
MNISTではベースモデル・他のSparse Networkよりも精度が高い事が示されており、CIFAR-10でもほぼ同様な結果になっています。
しかし、ImageNetでは他のSparse Networkよりも同じWeights数で高い精度を出しましたが、ベースモデルの精度には到達しませんでした。(pruning rateを下げれば対応可能?)
CIFAR-10でのSpeedに関しては、従来のモデルよりも数倍早くなった事が示されています。
補足・所感
- MNIST,CIFAR-10に関して、Sparse Learnigはモデルの表現力が下がる方法だと思っていたので、ベースモデルよりも高い精度が出るのは不思議でした。しかし、一つ一つの重みがタスクに対してより適切な表現を獲得した、と考えると単純に表現力を下げる操作ではないなと思い腑に落ちました。(自己解決)
- 上記に関連してですが、Sparse LearningはDropoutと近い方法な気がします。(Weightsの総数、つまりモデルの潜在的な表現力は下がっても、それぞれの重みが"優秀"になり、実際の表現力は高くなる)
- 本記事ではRegrowthを個人的なイメージから「復活」と表現しましたが、「再生」とかの方が本来の英語の意味には近いと思います。
最後までお読みいただきありがとうございました!
ご質問等あればお気軽にコメントください!