モデル刈り込みという概念があります。私はそれに関する知見が全くないので、どのような文献などが存在するかは知りません。ただ、興味を持ったので三時間で実験してみました。
脳の老化や細胞死を簡単なニューラルネットワークで表現してみましょう。統合失調症やアルツハイマーの患者の脳は萎縮していますが、水頭症でも知能が正常であるように脳の体積とタスク処理能力は比例するとは限りません。では、脳のどういった特性がタスクに影響するのでしょうか。ここでは簡単なタスクを用いて、その性質を確かめてみましょう。
まず、MNISTを数層のMLPに入力します。まず、比較として小さいMLPで学習させましょう。
当然どれも極めて制度がいいですね。
ではつぎに、これらのモデルの構造を深く調べてみましょう。ここではネットワーク科学の指標を用いて、モデルの特性を理解してみることにします。
↑
さすがに非現実的ですね。実際の脳構造とはかけ離れています。
とりあえずトイモデルとしてあそんでみることにします。MNISTをこのモデルより大きいモデルで学習させましょう。よりパラメータが多い複数の構造のニューラルネットワークを考えます。いったん学習を済ませた後、ランダムにノードを奪っていくことを考えます。ノードを奪うことを、MLPのある行、及び列がすべて消えることで表現します。奪った後のモデルで性能を評価し、さらに奪い...を繰り返して、細胞数やどんな特徴があるノードを奪うと性能がガクンと悪くなるかを分析してみましょう。
なるほど、確かに低下するんですね。面白い。一気に低下しているところがありそうですね。では、
次に、「どのノードを奪うと致命的か?」を中心性スコアや活性出力の分布に基づいて評価しましょう。
線形相関ではなく、特定のノードを奪ったときのdropが高いのが面白いですね。
次に、中心性が高いノードから順に削除したらどうなるか?を考えましょう。
綺麗に差が出ますね。やはりノードの特性は重要そうです。ただ、一定レベルで奪ったらものすごく強い影響が出るのがすごく面白いですね。
対称に、低中心性 or 低活性ノードから削ったらどうなるか?を見てみましょう。
今度は、low-activationのnodeを削除すると強い性能低下が出ることがわかりました。これはかなり面白い結果かなと思います。
次に、少し細胞死させて学習し、また少し細胞死させて学習し...を繰り返しましょう。こうすることで実際の脳の老化を再現できる(気がします)。
Retrainにより性能が上がった....??これがかりこみってやつなんでしょうか?
「再学習しなかった場合と比較した差分」や「どの層のノード損失が最もクリティカルか」を分析してみます。
再学習がちゃんと性能を維持していますね。
layerごとのnode損失を見てみると、二層では変化なさそうですが一部外れ値が見られます。
今回の結果は極めて単純なものですが、少しおもしろい結果が得られました。刈り込みに関する文献を読んでみようと思います。
次に、実際の脳の構造を模倣することを考えましょう。
脳は空間構造があり、ほとんどの細胞は空間内部での未結合していますが、一部の細胞は長距離結合することが知られています。
ノードを5つのグループに分け、ほとんどのノードはグループ内でのみ結合するが、ごく一部のノードはグループ間で結合しているようなMLPを考えます。情報はうち一つのグループにのみ入力され、別のグループから出力が出てくるようにします(視覚野などの模倣)
このMLPを使ってMNISTを学習してみましょう。
Test accuracy: 0.8816
accuracyが悪い!モデルをこのように構築するのは良くないんですね。まあ、今回は関係ないですが...
中心性指標はこういう感じで、あまり面白くない....モデル自体の改良が必要です。
まあそれは将来の仕事として、いったんこれに対する刈り込みをしてみます。
あまり違いはみられないイメージです。ただ、全結合に比べて減衰の幅が少ないのかな?とも思います。
では、先ほど同様中心性が高いもの・低いものから失わせる、や、局所・長距離に特異的に殺す、削除後の再学習をしてみましょう。
コードがうまく動かなかった(のと時間がない)ので、pruningのrandomしか動かしませんでした。わからないけど、先ほどよりも刈込に対する耐性が上がっているような(気がする)。
あと、神経細胞の死は領域特異的であることが多いので、同一グループや、接続が濃い小集団でたくさん死ぬようにしタラいいんじゃないかとは思いましたが、諦めました(時間がない...)
三時間試したことにしてはかすかに面白ことができた気がします。ただ、刈り込みとの比較はできていないので、もう少し検討が必要そうです。
次やることとしては、もう少し生物学的に妥当なモデルを作ることや、同じパラメータ数でも「浅く広い vs 深く狭い」構造の性能比較、DropoutやBatchNorm導入による耐性の評価、ネットワーク科学の指標をもっと用いること、ノード除去時の損失関数の変化量をモニターするなど細かい検討をすること、GNN/Attentionなどほかの構造でも試すことが考えられます。また、MNIST以外のタスクをするとまた違う結果が出るかもしれません。
ぜひ刈り込みの知見や改良の提案など頂けると嬉しいです。
コードはgithubにあります。
https://github.com/ShinMiz/neuro-age-task