2
5

More than 1 year has passed since last update.

CNNの改善版?CapsNetを知る

Last updated at Posted at 2022-01-23

何についての記事か?

Hintonさんが2017年に考えていたニューラルネットワークの新しい構造であるCapsule Network(CapsNet)(2017年NeurIPS)について紹介します。実装

2018年には文字分類ではなくSmallNORBに対する分類、色んな視点のテストデータに対してテストをにした論文が発表。
2019年には教師なし学習であるStacked Capsule Autoencoders(NeurIPS 2019)も発表されています。

何がすごいか?

①CNNでは相対的な位置をとらえていなかったが、相対的な位置も考慮可能になる。
②CNNでは全特徴の量(犬度20%猫度70%車度10%)で判定していたのに対し、各特徴の量(犬のしっぽ、ミラー、姿勢、幅)で判定できる。

メリット

・限られたデータで精度を高めることができる。
(・説明性としてはモデルが使いやすいかも?)

デメリット

・CNNと比べて精度が高いわけではない。2013年ごろの性能感。
・物体のクラス数が増えるとカプセル数が増えて学習が難しくなる。
・カプセルの個数の設計が難しい。
・個人的にはPrimaryCapsの解像度が6*6なので、画像が高解像になった時に対応できるのか気になる。

CapsNetが解決したい課題

①CNNのPoolingにより相対的な位置関係が失われる問題を解決しようとした。
CNNは顔のパーツがあれば、相対的な位置がおかしくても「顔」と識別してしまいます。
image.png
参考

HintonさんはPooling処理には以下のような問題意識をお持ちのようです。

Hinton: “The pooling operation used in convolutional neural networks is a big mistake and the fact that it works so well is a disaster.”

この相対位置関係の情報を消失してしまう問題に対してはTransformerのPositional Encodingなどがあります。
Hintonさんは別のアプローチで解決しようとしました。

②CNNは人間の知覚システムを模擬できていない
個人的には①の観点よりもこちらの観点が興味深いです。
人間は以下のような決定木によって分類を行っているらしい。
CNNはそれとは異なる方法により分類しているので、人間には理解しにくい現象がたびたび発生する。CapsNetではAとBがあり、AとBが○○の位置関係か→YES!で分類できるようにすることにより人が理解するのと同じようにモノを認識することを目指している。
説明性の観点からはCapsNetの方が使いやすそう。
image.png
参考

・CNNは大量の学習データが必要
CNNではワークが傾いたり、回転した画像を学習して傾いたワークも同じ数字と教えたりする必要がある。
CapsNetでは文字が何か、角度がどれくらいかの特徴情報が得られるため、傾いたRを学習しなくてもRと認識できる。
image.png

余談
人は判断を本当に決定木でしているのか検索していて認知科学で引っかかった人工知能と認知科学という記事が面白かった。(決定木とは関係ないが。)
機能が実現できれば実現方法は問わないという考えを機能主義というらしい。
空を飛べれば鳥と同じ方法でなくてもよいという考えが機能主義である。
人工知能は人と同じように分類できることを目指した結果、人とは異なる認知機能を持ち、人を超えた。
人の認知機能を超え、異なる認知方法を持ったがために、どうやって認識・分類したか説明しろとAIに言っても人には伝わらなくなった。
つまり、機能主義と説明性は相性が悪い。
「分類精度も人以上の100%で人が分かるように説明しろ」は矛盾した要求なのだろう...
人の認知方法に近づけたモデル構造を最初から検討していないと誰もが満足するような説明性は難しく、その場合は100%の精度ではなく、人に近い認識精度になってしまうのではないだろうか。CapsNetの構造を進化させてもいくつく先はSOTAではないのかもしれないがそれはそれで良さそう。
その他、一般化フレームワーク問題の内容も面白かった。境界を決めないと計算量が爆発してしまう問題。「人に迷惑をかけるな」と言われて悩む問題を言うらしい。

アイデア

物体の見え方(視点や回転)は認識には関係なく、(特徴の)相対的な位置関係情報が重要という考えでモデルが構築されている。
顔が回転しているかは認識に関係なく、目と鼻、鼻と口の位置関係が認識には重要。

The primary capsules are the lowest level of multi-dimensional entities and, from an inverse graphics
perspective, activating the primary capsules corresponds to inverting the rendering process.

レンダリング処理は座標情報から画像を作成を行っている。相対的な座標情報から輪郭の描画とか。
人間は逆に、認識する時に目からとらえた3Dの情報を相対的な位置関係情報に変換し、既に覚えている位置関係情報と比較し物体を分類している(とヒントンさんは考えている)。

論文中ではこの方法(低次元の情報に変換する方法)を「inverse graphics」とよび、このアイデアに基づいてprimary capsules(低次元情報、鼻と口の位置関係がどうかなどの情報)を作成している。
どの、primary capsulesが活性化したかで画像を構築する処理はレンダリング処理に相当し、DigitCapsがこれに該当している。

基となるアイデアとしては非常にシンプルだが、capsulesを学習するアルゴリズムがなく、他に論文が出ていなかった。そこを解決しているのがDynamic routing。

2018論文では色んな視点のテストデータに対してCNNの認識よりtest errorsが45%減っているらしい。

CapsNet構造

CapsNetの構造は以下のような構造になっています。
image.png
①入力画像(1x20x20)をCNN(256*9*9)で畳み込んで256x20x20の特徴マップを出力する。
②8個の畳み込みユニット(32*9*9)で畳み込んで8×6×6のPrimaryCaps(ui:iは1~32*6*6。②の紫の箇所を指す。ui は8次元)を32個出力する。
ここまでは通常のCNNにより低次元の特徴を圧縮して、32個のPrimaryCapsを作る。

③ui(1×8)とWij(8×16)(j:1~10)の積をˆuj|i (1×16)とする。
uiは低次元の特徴を表現するカプセル。
Wは低次元の特徴(例:鼻、口、目。数字だったら横棒、縦棒)と高次元の特徴(例:顔。数字だったら5とか10とか)の関係性を保持している行列で顔は鼻を中心として存在する、顔は鼻の10倍の大きさなどの情報を保持している。
Wとuの積を取ることで検出された目の位置からすると顔はこの位置にあるはずだ、検出された鼻の位置からすると顔はこの位置にあるはずだという情報が得られる。数字だったら、横棒の位置からすると4はこの位置にあるはずだ、縦棒の位置からすると4はこの位置にあるはずだという情報が得られる。

高次元の特徴に対する各低次元のカプセルの情報を合算することで顔の情報(ベクトル)や数字4の情報(ベクトル)が求まる。
合算する際は後述するDynamic Routingにより求められるcijにより、ˆuj|iの情報をどれくらい上位のカプセルに渡すかを調整している。縦棒の位置からすると4はこの位置にあるはずだという情報は4を表現するカプセルにどれくらい渡すか。
合算し得られたベクトルsjに対し以下のsquash関数でベクトルの向きは保持したままノルムを0~1に変換する。
image.png
image.png
これにより、16chの10個のDigitCapsが出力される。
この16chのそれぞれが文字幅や姿勢を表現している。(後述)

④L2ノルムで各クラス(10クラス:0~9)の確率を出力する。

通常のCNNと異なるのは256 x 20 x 20から32 x 1 x 6 x 6ではなく、32 x 8 x 6 x 6を出力している点。
1次元のスカラではなく、8次元のベクトルを出力している点が異なる。Pooling処理もなく、バイアスもありません。
Wが相対的な位置情報を保持している。cで分配を決める。

以下の図は③④のまとめ引用:CapsNet-Tensorflow
image.png

イメージとして書くと以下の図のようになる。
cを学習して求める必要がある。
最初は値は0からスタートする。
image.png

Dynamic Routing

cを学習する方法です。foward時に更新される。Wはこの後で説明する損失関数を使ってBackpropで更新します。
2行目:bijはcの計算途中の値であり初期値は0。lは低次元の下位カプセル。l+1は高次元の上位カプセル。
3行目:r回繰り返す。論文ではr=3。rが大きいと過学習状態になる。
4行目:cの合計は1となるようにする。cは下位カプセル数分(6*6*32)存在する。
初回は0なので、低次元のカプセルが3つの場合は各cの値は0.33となる。
5行目:現状のcの値を使って、上位カプセルに情報を送りsjを作る。
6行目:sjをsquashしてvjを求める。
7行目:現状の重みbijと下位カプセルと上位カプセル(vj)の類似度を足して更新している。
最終上位カプセルと類似度が高くなる下位カプセルの配分がより行われるようにcが更新される。

image.png

損失関数

各カプセルに対してLを求めて合計することでLossを求める。
正解のラベルの時はTk=1、不正解ラベルの時は0。
逆伝搬してWが求まる。
image.png

再構築による正規化

reconstruction用のsubnetworkをつけるとMNISTのerror率が低下する。
正解ラベルと同じ出力ベクタ-を再構築する。
入力画像と再構築画像の2乗誤差が最小になるようにWを更新する。
image.png
image.png

カプセルの各次元が表す情報

幅や歪み、各数字の情報(2の尾の長さなど)などが含まれる。組み合わせで表現されているものもある。
各次元の値を-0.25~0.25の間で0.05ずつ変化させた様子が以下。
image.png

結果

MNIST、training 60K images, test 10K images。
Baselineはこの論文。augmentありで0.21% error。augmentなしだと0.39% error。
CapsNetはaugmentationはなしでerror率は0.25%。
パラメータ数はBaseLineが35.4Mに対して,CapsNetは8.2M(reconstruct subnetworkあり),6.8M(reconstruct subnetworkなし)。
image.png

アフィン変換に対するロバスト性

affNISTに対する正答率を比較。
training中はaffine変換はなし。
BaseLine(CNN base):MNIST→99.22% affNIST→66%
CapsNet:MNIST→99.23% affNIST→79%
image.png

Segmenting

CapsNetはCにより着目するカプセルを決めているので一種のattentionになっている。
5の特徴はこのカプセルから情報を取ってくるからほかのカプセルは無視と。
その場合、文字が重なっていてもそれぞれの文字の情報を分離して抽出できるのではとHintonは考えた。
数字が重なったMultiMNISTを使用。上側:入力画像 下側:再構築画像 R(正解ラベル)、L(予測ラベル)。赤と緑が数字のそれぞれの予測。
CapsNet:5% error

image.png

他のデータセットに対して

総じてDeep先駆けの時代のモデルと同程度の精度感。
CIFAR10は10.6% error。7つのモデルをアンサンブル。構造は同じ。(2013年のモデルがこれくらいの精度)
smallNORBは2.7% error。これは2011年SOTAと同程度。
以下はsmallNORB datasetの一例。
image.png

参考

Understanding Hinton’s Capsule Networks. Part 2. How Capsules Work.
CapsNet の PyTorch 実装

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5