LoginSignup
13
6

More than 3 years have passed since last update.

[論文読み + α] 宝くじ仮説を試してみる

Last updated at Posted at 2019-12-19

こんにちは。@YusukeKanaiです。
この記事はACCESS Advent Calendar 2019 20日目の記事です。

この記事では論文読み編と実践編の2部構成となっております。
論文読み編ではICLR2019で発表されたTHE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS (J. Frankleら)(Best Paper Award受賞)の論文を紹介していきます。
実践編では上記の論文をベースにVGGベースのニューラルネットワークモデルで追試した方法とその結果を紹介します。

論文読み編

概要

J.Franleらの論文はニューラルネットワークの一部を間引くpruningに関する研究になります。

ニューラルネットワークのpruning ニューラルネットワークのpruning(図はReference 2.の論文中から引用)

このようなpruningを行えば学習するパラメタ数も減るし、FeedForwardの処理も早くなりえます。一方でネットワークをシンプルにする分だけ精度を犠牲にしているんじゃないのかと思われるかもしれませんが、実際にはそんなに精度を下げることは無いようです。90%も間引いても十分使える場合もあります。

しかしながら、そんな有力なpruningが実用化されていません。理由は最初から間引いて良い具合に良い精度になる学習結果を得るのが難しいからです。

J.Frankleらの研究の成果はそんな難しいpruningの仕方がニューラルネットワークの最初に与える初期値と関係することを発見したことになります。

宝くじ仮説とは?

J. Frankleらはpruningに関して 宝くじ仮説 という仮説を立て、以下のように説明しています。

乱数で初期化された密なニューラルネットワークに含まれるサブネットワークが単体で学習されたとき、その精度
は高々元のニューラルネットワークと同数の繰り返しの学習で元のニューラルネットワークの精度に匹敵する。(そんなサブネットワークが存在する)
\exists m \mbox{ s.t. } j' < j \mbox{(iteration) and } a' > a \mbox{(accuracy) and } \|{m}\|_0 \ll \|{\theta}\| \\
m \text{: mask,  } \\
j , j' \text{:iteration without mask, iteration with mask,   } \\
a, a' \text{:accuracy without mask, accuracy with mask,  } \\
\|{m}\|_0, \|{\theta}\|: \text{:num of pruned weights, num of total weights.  }

何故 宝くじ というのでしょうか? それは(特に巨大な)ニューラルネットワークには大量のクジ(大量のネットワークの組み合わせ)が含まれていると比喩されることから始まります。
初期値を与えたニューラルネットワークを学習していくと、学習が進むに従って不要なリンクは小さい重みの値となりやがては間引かれるような状態になっていきます。逆に言えば当たりクジが浮き彫りになっていくわけです。

ニューラルネットワークのPruningの方法

論文ではiterative-pruningを採用しています。これは徐々に重みを間引いていく方法です。論文中では一回で間引く(oneshot-pruning)よりも良い結果がえられたと報告しています。

Iterative pruningの手順

  1. ネットワークに初期値 $\theta _0$を与える
  2. $j$回の学習を繰り返す
  3. 重み値の小さい方から$p$%のpruningを行うmask[$m$]を生成する
  4. mask[$m$]を初期ネットワークに適応して2.に戻る

ではオリジナルのネットワークとこのpruningで得たネットワークとの比較、またpruningと初期値の関係をみていきましょう。
論文では6種類のニューラルネットワーク(Lenet-300-100, CNN-2, CNN-4, CNN-6, Resnet18, VGG19)を検証していますが、全て載せると記事も巨大になるので4種類(Lenet-300-100, CNN-2, CNN-4, CNN-6)の結果を紹介していきます。

スクリーンショット 2019-12-19 9.24.28.png
(図はReference 1. の論文中から引用)

残りは興味があれば論文を読んでみてください。

Lenet-300-100(Full-Connected Network)の当たりくじ

有名なデータセット MNIST で検証しています。

Pruningによる精度向上

Pruningなし(100%), Pruningあり(weight残り 51.3%, 21.1%, 7.0%, 3.6%, 1.0%), 再初期化ありPruningあり(weight残り 51.3%, 21.1%)で比較した結果です。

fcn_results.png
(図はReference 1. の論文中から引用)

結果から以下の3つのことが言えますね

  1. Pruningあり(Weight残り21.1%)が最も精度よくオリジナルのネットワークよりも精度が良い
  2. あまりPruningし過ぎてもネットワークを悪くするだけ(Weight残り1.9%)
  3. 初期の重みを変えてPruningしても精度向上が見込まれない(見込まれにくい)

Pruningによる学習速度の向上

十分に学習されたと判断できる時に学習を止めた場合の繰り返した学習回数をみてみましょう。

スクリーンショット 2019-12-17 21.39.50.png
(図はReference 1. の論文中から引用)

グラフはちょっと見えにくいですが、確かにPruningをすることで繰り返しの学習の回数が減っている様子がわかります。Weight残り13.5%あたりが極小という感じでしょうか。
論文中ではWeight残り21%の時38%ほど早く学習を終えていると報告されています。

また、早く学習を終えたからといって精度が悪くなっている訳ではなくむしろ良い結果を得ています。

Convolutional Networkの当たりくじ

こちらは有名なデータセットCIFAR-10で検証しています。

Pruningによる精度向上

Conv-2では20k回学習を、Conv-4では25k回学習を、Conv-6では30k回学習した時のトレーニングデータ、テストデータで精度を調べた結果を比べて見ています。

cnn-pruning_accuracy.png
(図はReference 1. の論文中から引用)

初期の重みを維持しての98%未満のpruningに置いて、トレーニングデータでの精度はほぼほぼ100%に達しており、この範囲内でテストデータのpruningによる精度がpruningなしに比べて(つまりオリジナルのニューラルネットワークに比べて)向上していますね。
最も良い時でConv-2は3.4%(weight残り4.6%), Conv-4は3.5%(weight残り11.1%), Conv-6では3.3%(weight残り26.4%)向上しています。

一方でそれぞれのネットワークの初期重みを再設定して学習した時にはオリジナルのニューラルネットワークよりも精度が落ちる傾向にあります。

Pruningによる学習速度の向上

こちらもLenet-300-100の時と同様に十分に学習されたと判断できる時に学習を止めた場合の繰り返した学習回数をみてみましょう。

スクリーンショット 2019-12-19 9.15.19.png
(図はReference 1. の論文中から引用)

やはり、初期の重みを維持したままでのPruningありの時の方が早く学習が収束し、その時の精度も高いとることがわかります。最も良い時でConv-2では3.5倍(weight残り8.8%)、Conv-4では3.5倍(weight残り9.2%)、Conv-6では2.5倍(Weight残り15.1%)という結果が得られています。
一方で初期の重みを再設定して再学習すると学習速度も低下し、精度も下がる傾向がある(ただ必ずしも下がる訳ではなくたまたま残ったweightの中に当選くじがある場合もあります)ことがわかります。

どんなくじが当たりくじになるのか?

J.Frankleらの研究ではIterative pruningという手法で重み値の小さい方から繰り返し$p$%ずつpruningする方法でpruning maskを作りました。この削り方が本当に良いのでしょうか。

この問いに対して検証した論文Deconstructing Lottery Tickets: Zeros, Signs, and the SupermaskがH. Zhouらによって発表されています。

H.Zhouらはpruningのmaskの判断基準を9つ用意して比較しています。

  1. 学習後の重みで大きいものを残す
  2. 学習後の重みで小さいものを残す
  3. 学習前の重みで大きいものを残す
  4. 学習前の重みで小さいものを残す
  5. 学習前/学習後の両方の重みで大きいものを残す
  6. 学習前/学習後の両方の重みで小さいものを残す
  7. 学習前後で重みの絶対値の差が大きいものを残す
  8. 学習前後で重みの差が大きいものを残す
  9. ランダム

で結果がこれです。

スクリーンショット 2019-12-19 10.49.37.png
スクリーンショット 2019-12-19 10.49.46.png
(図はReference 5. の論文中から引用)

案の定という感じですが、「学習後の重み値で大きい値のものを残す」「学習前後で重みの絶対値の差が大きいものを残す」のが良いという結果になりますね。

また、pruningのmaskのみを学習するSupermaskというのもH.Zhouらは提案しています。詳細なアルゴリズムの説明は省きますが、先ほどの判断基準をベースに確率的に良いmaskを学習する仕組みになっています。論文中では先ほどの判断基準に加えて、学習前後での+/-符号の要素も加えて比較しています。その結果が以下になります。

mask_ans_initial_weight_model.PNG
(図はReference 5. の論文中から引用)

なんと、重みの初期値と(学習した)pruningだけでMNISTでは95%以上、CIFAR-10では60%以上の精度を実現しています。

実践編

実装はgithubに公開しています.
https://github.com/YusukeKanai/LotteryTicketHypothesis

CIFAR-10というデータセットで今回は検証しています。
また使用したニューラルネットワークは6層の畳み込み層と2層の全結合層で構成されるVGGベースのネットワークになります。学習中の入力値のBatchサイズは48, 勾配法はAdam法としています。
詳細な構成はこちらを参照してください。

もっと良い精度を出すネットワークは提案されていたり、多くのサンプルがあったりしますが、今回は精度を追求するものではないので、ある程度の精度を出せるものでそれほど学習に時間がかからないものを探し出しました。

Pruningの方策ですが、この論文を参考に、学習後の重みが初期の重みと同じ符号で値が大きいものを残すようにしました。

また重みを更新したタイミングで剪定のために重みを0にしていたものが変化する可能性もあるため、更新後にPruning対象の重みを0にリセットさせています。

検証

さて上記の条件のもと、検証進めましたがあまりいい結果を得られていないようです。今回はpruningの方法を1回の学習の結果からpruningのmaskを決定するone-shotの方法を採用したのでそれが良くなかったのかもしれません。紹介した論文にあるようなiterative-pruningの方法は時間が足りず確かめられていませんが、そちらを採用した方がより宝くじ仮説やpruningの正当性が浮き出るとおもいます。

オリジナルのネットワークのパフォーマンス

ぽん!
オリジナルのloss
オリジナルの精度

テストデータでベストな精度は84.82%でした。この結果をみて検証対象のニューラルネットワークとしては十分な及第点と判断しました。

Pruningした初期重みでのネットワークの精度

おりゃ!

初期重みネットワークの精度

確かに良くなっているものもあります(64%のPruningつまり重みを36%残した状態で精度19.10%が最大)...が思っていたほど高くなかったです(中にはオリジナルよりも精度が低いものさえあります)。

Pruningした重みでのネットワークの学習

$p$%のpruningした後に学習し、EarlyStoppingによって学習が停止したときの最も精度の良かったものを記録しました。各々のpruning率に対して10回繰り返し、そのときの精度と学習回数のそれぞれ平均をとってみています。pruning率の変化による精度の変化、学習回数の変化の結果を順番に紹介します。

Pruningによる精度の変化

pruning_accuracy_mean.png

ちょっと見えにくいですが、pruningした方が良い結果を得ています。52%のpruningをした時が最も良い精度になっており、1.4%の精度向上がみられています。また、86%までpruningをしても精度がオリジナルのネットワークに比べて落ちないという結果も得ています。

Pruningによる学習回数の変化

pruning_iteration_mean.png

Pruningをした時の方が学習回数が少なくて済むという結果を得られています。特に先ほどの「Pruningによる精度の変化」の項目でベストだった52%のpruningしたネットワークでは1.17倍速く、オリジナルに比べて精度が落ちない86%のpruningしたネットワークでは1.36倍速くなっています。

重みの初期値とPruning

最後に重みの初期値とPruningの関係をみてみましょう.
以下の3つの条件を比較しています。

  1. オリジナルのニューラルネットワークで学習した場合(=origin)
  2. オリジナルのニューラルネットワークと同じ初期の重みでpruningしたネットワークで学習した場合(=pruning)
  3. オリジナルのニューラルネットワークと違う初期の重みでpruningしたネットワークで学習した場合(=reinit_pruning)

train_loss.png
train_accuracy.png
test_loss.png
test_accuracy.png

思っていたほど差が出ませんでした。トレーニングデータでのlossとaccuracyの差はほぼありません。テストデータでは感覚的にですが微妙にpruningが一番よく(オリジナルよりも良い精度を出し早く学習する傾向にある)、reinit_pruningが一番よくない(オリジナルよりも悪い精度を出し遅く学習する傾向にある)感じがします。

Reference

  1. THE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS
  2. Learning both Weights and Connections for Efficient Neural Networks
  3. 【論文検証】ニューラルネットワークのプルーニングについて
  4. ニューラルネットワークにおける「宝くじ仮説(Lottery Ticket Hypothesis)」
  5. Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask
  6. How the Lottery Ticket Hypothesis is Challenging Everything we Knew About Training Neural Networks

最後に

明日は@jyoppomuのデータ解析の記事になります。お楽しみに!

13
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
6