最近ResNet18を使用する機会があり、かつ構造を理解することを迫られたので、色々調べていると以下の表にたどり着きました。
なるほどなるほどこうなっているのね...
入力がこれで出力がこれで...???いっちょんわからん...
7×7や3×3の隣にある数字は何?計算してもoutput sizeにならない...そもそも各層がどのようにつながっているのかわからない...
と多くの疑問を持ったので色々調べたことからこの表を読み解いていこうと思います。
目次
そもそもResNetとは?
ResNetがCNNの一つであるというのはconvやらpoolやらが前出の表に出てきていることからもお分かりかと思います。まずCNNをよくわかっていないという方はこちら の記事がわかりやすかったので読むことをお勧めします。
本来CNNは理論的には層が深いほどより高次の特徴を捉えることができ精度が高まるとされています。しかし層を深くすると学習が進まないという問題 (学習が進まないといえば勾配消失問題が思い浮かびますが、これはdegradation problemらしい) があり、それを解決したのが残差(Residual)を足し合わせる処理があるResidual Network通称ResNetです。具体的には
input→
↓
畳み込み層
↓
畳み込み層
↓
output + input ←
のように畳み込み層を通った値と通る前の値を足し合わせます。この外側のinputの流れをshortcutと呼び、表の
のようなブロックごとにshortcutが作られています。
表を読み解く
表をよく見ると基本的にはデータの流れが分かりにくいために構造が理解できなかったと思います。
そのため表を入出力のデータサイズに着目して上から読み解いていこうと思いますが、
- relu,batchnormなどの直接データサイズに関係ないものについては言及しない
- shortcutについても触れない
ので留意してください。
なおResNetはImageNetを前提にしているので入力サイズは224*224です。(余談ですがpytorchではnn.AdaptiveAvgPool2d((1, 1))を全結合層の前に入れることによって画像のサイズによらず学習ができるようになっているらしいです。)
##conv1
224224の入力が入ると77のカーネル(フィルタ)、stride 2で畳み込まれます。この時入力値を上下左右3ずつpaddingするので230230のデータを畳み込んでると考えて問題ないです。すると1~7マス目,3~9マス目...223~229マス目が畳み込まれ結果として112112の出力が得られます。この操作を64(channel)枚のカーネルで行うため結果11211264のサイズの出力が得られます。
簡潔にサイズの遷移をまとめると
2242241
↓
(2302301)
↓
11211264
となりました。
##conv2_x
この層では最初にpooling層があります。33のカーネルサイズ、stride 2,padding 1なので実質114114から縦横1~3マス目,3~9マス目...111~113マス目の範囲の最大値が抽出され結果565664の出力が得られます。
次に、畳み込み層があり33のカーネルサイズ、stride 1,padding 1で畳み込まれます。今までのように実質5858から縦横1~3マス目...55~57マス目の範囲が畳み込まれ、結局565664の出力が得られます。
このように3*3のカーネルサイズ、stride 1,padding 1で畳み込むと入力サイズと同様な出力が得られるので今後説明は割愛します。そして、ここではこのような層が後三層続きます。
この層でも同じようにサイズの遷移をまとめると
11211264
↓
(11411464)
↓
565664
↓
...
↓
565664
となりました。
##conv3_x
以下畳み込み層が続きますが、表の説明をよく読むと各層の先頭の畳み込み層だけstrideが2であることがわかります。
また、64から128に横の数字が増えていますが、これがカーネルの枚数だと辻褄が合わないことから出力のチャネル数だということが推測でき、pytorchの実装コードを読むことで確信に変わると思います。
入力のチャネル数と出力のチャネル数が変わってもいいの??
とイメージがわかない方はこちらを読むことでイメージの助けになると思います。
この層の先頭畳み込み層では33のカーネルサイズ、stride 2,padding 1出力チャネル128なので実質5858から縦横1~3マス目,3~9マス目...55~57マス目の範囲の最大値が抽出され結果2826128の出力が得られます。
以下三つの畳み込み層はconv2_xの最後三層と出力チャネル以外同じなので省略してサイズの遷移をまとめると
565664
↓
(585864)
↓
2828128
↓
...
↓
2828128
##conv4_x
conv2_xと同じなので省略してまとめると
2828128
↓
...
1414256
##conv5_x
同上
1414256
↓
...
77512
##avg pool及びcf(全結合層)
avg pool(Average pooling)層で77のそれぞれの要素の平均を取って11512にしています。また、全結合層で111000にしています。1000はImageNetのクラス数です。
まとめると
77512
↓
11512
↓
11*1000
よってResNet18のすべての工程が読み解けました!!
終わりに
もともとは表のすべてを解説しようと思っていたのですが面d...時間がないので辞めました。ResNet34はResNet18とほぼ同じ構成、ResNet50以上はソースコードのBottleneckクラスあたりを読めば解読できると思います。