⚙️ ResNet50実装記録 Vol.5:CIFAR-10向けResNetへの構造変更
はじめに
こんにちは、DL(ディープラーニング)を勉強している大学1年生です。
2025年の4月からプログラミングを本格的に勉強し始め、GCI(東大松尾研が提供するデータサイエンス入門講座)を修了し、現在DL基礎講座を受講しています。
今回は、DL基礎講座の最終課題に向けた学習のアウトプットとして、ResNet50を用いた画像分類に挑戦します。
この記事は、自分自身のポートフォリオとして ResNet50の実装と、その精度改善施策 を行うプロセスを記録し、共有することを目的としています。
※この記事は実装と改善のプロセスに焦点を当てており、ResNetの理論的な詳細には深く立ち入りません。理論的背景については、適切な参考文献をご参照ください。
前回の記事(Vol.4)ではTrivialAugmentを導入し、モデルの汎化性能を大幅に改善しました。今回は、精度向上の次のステップとして、モデル自体の構造的なボトルネックを解消することを目指します。
具体的には、CIFAR-10の画像サイズ(32x32)に最適化されていないResNet50の初期層(Stem)を改造します。
🎯 今回の施策:CIFAR-10向けResNetへの構造変更
1. 問題の仮説:情報損失の防止
ResNet50の原著論文や、PyTorchの torchvision で提供されている標準モデルは、224x224ピクセルという高解像度の画像を扱うImageNetデータセットに合わせて設計されています。
今回使用するCIFAR-10の画像は32x32ピクセルという非常に低画質なものです。標準のResNet50の初期層(Stem)は、この小さい画像を急激にダウンサンプリングしてしまい、以下のような問題を引き起こす可能性があります。
| 層 | 標準 ResNet50 (ImageNet用) の設定 | 32x32画像での出力サイズ | 発生しうる問題 |
|---|---|---|---|
| Conv1 |
kernel_size=7, stride=2
|
16x16 | 最初の畳み込みで既に情報が粗くなる |
| MaxPool |
kernel_size=3, stride=2
|
8x8 | さらに画像サイズが1/4に縮小 |
仮説: 特徴抽出が本格的に始まる前に画像サイズが8x8まで縮小されてしまうと、重要な低レベルな情報が失われ、精度向上の妨げになっている可能性があります。
2. 具体的変更:Stem層の改造
この情報損失を防ぐため、ResNetをCIFARタスクで用いる際の一般的な手法に従い、初期のダウンサンプリング処理を緩和します。
| 項目 | 変更前 (標準 ResNet) | 変更後 (CIFAR ResNet) | 目的 |
|---|---|---|---|
| Conv1 |
kernel_size=7, stride=2
|
kernel_size=3, stride=1, padding=1
|
画像サイズを維持しつつ、特徴抽出を行う。 |
| MaxPool | 削除 | 削除 | サイズ縮小をボトルネック層に委ねる。 |
この変更により、最初のステージが終了するまで画像サイズは32x32を維持し、モデルがより多くの情報を保持した状態で学習を開始できるようになります。
3. 実装コードの変更箇所
ResNet クラスの __init__ メソッド内の以下のコードを変更しました。
# 変更前 (標準的な実装)
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 変更後 (CIFAR-10向け Stem構造)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # 3x3, stride 1に変更
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
# MaxPool層は削除し、self.res1の前に挟む層はなしとする
# ... (以降のBottleneck層の定義は変更なし)
承知いたしました。前回の記事からの流れを汲み、モデル構造変更の重要性とその具体的な実装内容をMarkdown形式で記述します。
Markdown
⚙️ ResNet50実装記録 Vol.5:CIFAR-10向けResNetへの構造変更
はじめに
こんにちは、DL(ディープラーニング)を勉強している大学1年生です。
2025年の4月からプログラミングを本格的に勉強し始め、GCI(東大松尾研が提供するデータサイエンス入門講座)を修了し、現在DL基礎講座を受講しています。
今回は、DL基礎講座の最終課題に向けた学習のアウトプットとして、ResNet50を用いた画像分類に挑戦します。
この記事は、自分自身のポートフォリオとして ResNet50の実装と、その精度改善施策 を行うプロセスを記録し、共有することを目的としています。
※この記事は実装と改善のプロセスに焦点を当てており、ResNetの理論的な詳細には深く立ち入りません。理論的背景については、適切な参考文献をご参照ください。
前回の記事(Vol.4)ではTrivialAugmentを導入し、モデルの汎化性能を大幅に改善しました。今回は、精度向上の次のステップとして、モデル自体の構造的なボトルネックを解消することを目指します。
具体的には、CIFAR-10の画像サイズ(32x32)に最適化されていないResNet50の初期層(Stem)を改造します。
🎯 今回の施策:CIFAR-10向けResNetへの構造変更
1. 問題の仮説:情報損失の防止
ResNet50の原著論文や、PyTorchの torchvision で提供されている標準モデルは、224x224ピクセルという高解像度の画像を扱うImageNetデータセットに合わせて設計されています。
今回使用するCIFAR-10の画像は32x32ピクセルという非常に低画質なものです。標準のResNet50の初期層(Stem)は、この小さい画像を急激にダウンサンプリングしてしまい、以下のような問題を引き起こす可能性があります。
| 層 | 標準 ResNet50 (ImageNet用) の設定 | 32x32画像での出力サイズ | 発生しうる問題 |
|---|---|---|---|
| Conv1 |
kernel_size=7, stride=2
|
16x16 | 最初の畳み込みで既に情報が粗くなる |
| MaxPool |
kernel_size=3, stride=2
|
8x8 | さらに画像サイズが1/4に縮小 |
仮説: 特徴抽出が本格的に始まる前に画像サイズが8x8まで縮小されてしまうと、重要な低レベルな情報が失われ、精度向上の妨げになっている可能性があります。
2. 具体的変更:Stem層の改造
この情報損失を防ぐため、ResNetをCIFARタスクで用いる際の一般的な手法に従い、初期のダウンサンプリング処理を緩和します。
| 項目 | 変更前 (標準 ResNet) | 変更後 (CIFAR ResNet) | 目的 |
|---|---|---|---|
| Conv1 |
kernel_size=7, stride=2
|
kernel_size=3, stride=1, padding=1
|
画像サイズを維持しつつ、特徴抽出を行う。 |
| MaxPool | 削除 | 削除 | サイズ縮小をボトルネック層に委ねる。 |
この変更により、最初のステージが終了するまで画像サイズは32x32を維持し、モデルがより多くの情報を保持した状態で学習を開始できるようになります。
3. 実装コードの変更箇所
ResNet クラスの __init__ メソッド内の以下のコードを変更しました。
# 変更前 (標準的な実装)
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 変更後 (CIFAR-10向け Stem構造)
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # 3x3, stride 1に変更
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
# MaxPool層は削除し、self.res1の前に挟む層はなしとする
# ... (以降のBottleneck層の定義は変更なし)
- 値の改善度合いは実験工数の都合上後程掲載します。
📝 次回以降の改善施策
構造変更後のモデルで再度学習を行い、精度の変化を確認します。次の施策は以下の通りです。
-
Optimizerの比較検証: 現在使用しているSGDから、AdamやAdamWといった適応的学習率手法に変更し、収束速度と最終的な汎化性能を比較します。
-
Test Time Augmentation (TTA): 精度改善の最終手段として、TTAを導入します。これは、モデルが完成した後の「推論時」に行う施策であるため、最後に取り組むこととします。
📅 おわりに
今回は、モデル構造の根本的な問題を解決する施策を適用しました。この構造変更が、次なる精度向上にどれだけ貢献するか、結果が楽しみです。
最後まで見てくださってありがとうございました!