🚀 ResNet50実装記録 Vol.4:データ拡張 (Data Augmentation) の追加
はじめに
こんにちは、DL(ディープラーニング)を勉強している大学1年生です。
2025年の4月からプログラミングを本格的に勉強し始め、GCI(東大松尾研が提供するデータサイエンス入門講座)を修了し、現在DL基礎講座を受講しています。
今回は、DL基礎講座の最終課題に向けた学習のアウトプットとして、ResNet50を用いた画像分類に挑戦します。
この記事は、自分自身のポートフォリオとして ResNet50の実装と、その精度改善施策 を行うプロセスを記録し、共有することを目的としています。
※この記事は実装と改善のプロセスに焦点を当てており、ResNetの理論的な詳細には深く立ち入りません。理論的背景については、適切な参考文献をご参照ください。
前回の記事ではW&B (Weights & Biases) を活用した実験トラッキング環境を完成させました。今回は、その環境を使い、学習データの前処理を変更する改善施策を適用し、ベースラインモデルの精度向上を目指します。
🎯 改善施策の立案と選定理由
1. Optimizer / Scheduler の選定方針(継続)
既存の学習手法については、引き続き高い汎化性能を目指し、以下の構成を継続します。
| 項目 | 設定値 | 選定理由 |
|---|---|---|
| Optimizer | SGD (Momentum 0.9) | 汎化性能の向上: Adamなどの適応的学習率手法は収束が早いものの、画像分類タスクにおいては最終的な**汎化性能(Generalization)**がSGDに劣る傾向があるため (Wilson et al., 2017)。今回は最終的なTest Accuracyの向上を最優先します。 |
| Weight Decay | 適用 | 過学習抑制: 重み減衰を適用し、モデルのスパース性を高めることで過学習を抑制します。 |
| Scheduler | CosineAnnealingLR | 学習率をコサイン曲線に従って徐々に減少させ、収束を安定させるため。 |
| 目的関数 | Cross Entropy | まずはこのままベースラインを確立し、過学習が確認された場合は Label Smoothing の導入を検討します。 |
参考文献: The Marginal Value of Adaptive Gradient Methods in Machine Learning (Wilson et al., 2017) https://arxiv.org/abs/1705.08292
2. データ拡張 (Data Augmentation) の追加
今回の施策として、モデルが様々なパターンのデータに対応できるよう、データ拡張を適用します。
施策の概要
データセットの枚数が限られているCIFAR-10に対して、訓練用データの前処理に以下の拡張を加えます。
- RandomCrop (ランダム切り抜き): 画像をランダムに切り抜くことで、物体が画像のどこに写っていても認識できる能力を向上させる。
- RandomHorizontalFlip (ランダム水平反転): 画像を水平方向にランダムに反転させることで、モデルに左右反転した特徴も学習させる。
TrivialAugmentの適用検討
より強力な拡張手法として TrivialAugment(ランダムに選んだ単一の拡張を適用する手法)の採用を検討しましたが、まずはシンプルな RandomCrop と RandomHorizontalFlip を適用した上で、ベースライン2を構築します。その結果を見て、必要に応じて強力なAutoAugment系の手法を導入する方針とします。
変更予定のコード(transforms の修正)
学習データの前処理部分を以下のように修正し、データ拡張を有効化します。
transform_train = transforms.Compose([
# ランダムに32x32を切り抜く
transforms.RandomCrop(32, padding=4),
# ランダムに画像を水平反転
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
3. データ拡張 (Data Augmentation) の追加結果
前回の記事で提案した RandomCrop と RandomHorizontalFlip よりも強力な手法である TrivialAugment を実装し、ベースラインモデルと比較検証を行いました。
以下の表が、ベースライン(素の状態)とTrivialAugmentを適用したモデル(エポック数を増加)の比較結果です。
| 指標 | ベースライン (Epoch 5) | 今回 (Epoch 20, TrivialAugment) | 評価 |
|---|---|---|---|
| Train Accuracy | 92.71% | 73.37% | ⬇️ 大幅に低下 |
| Val Accuracy | 79.39% | 83.34% | ⬆️ 約 +4% 改善 (SOTA級への第一歩) |
| 状態 | 過学習 (Overfitting) | 高汎化 (High Generalization) | 理想的な状態に近づいた |
結果の考察
TrivialAugmentの導入により、学習データ(Train)の精度は大きく下がりましたが、検証データ(Val)の精度は4%近く向上しました。
これは、学習時に画像に様々な変化を加えたことで、モデルが訓練データを丸暗記(過学習)するのを防ぎ、未知のデータに対する汎化性能が大きく高まったことを示しています。
- Train Lossが高いのは計算内:学習時により難しい多様なデータを見ているため、Train Loss(およびAccuracy)が下がるのは自然な結果であり、汎化に成功した証拠と言えます。
4. 次回以降の改善施策
過学習の問題が大きく改善されたため、今後はモデルのポテンシャルを最大限に引き出すための施策に移行します。
① 学習エポック数の増加
TrivialAugmentの適用により、モデルがまだ学習できる余地があることが示唆されました。今後は、エポック数を大幅に引き上げて、Train LossがVal Lossに再び近づくまで学習を継続します。
② 推論時のデータ拡張 (TTA: Test Time Augmentation) の導入
推論時にもデータ拡張(水平反転など)を複数回適用し、その予測結果を平均化する TTA (Test Time Augmentation) を導入します。これにより、予測時の安定性と精度向上が期待できます。
③ モデル構造の最適化(CIFAR用 ResNetへの改造)
現状使用している標準の torchvision.models.resnet50 は、224x224ピクセル(ImageNetサイズ)の画像を想定して設計されています。
CIFAR-10のような 32x32の非常に小さい画像 にそのまま使用すると、**最初の層(Stem)**で画像サイズが急激に縮小され、重要な情報が失われている可能性があります。
-
標準 ResNet (ImageNet用):
7x7 Conv(stride 2) +MaxPool→ 画像サイズがいきなり1/4になる。 -
改善策 (CIFAR用):
3x3 Conv(stride 1) + MaxPoolなし → 画像サイズを維持して特徴を抽出し、情報ロスを防ぐ。
次回は、このResNetのStem部分をCIFAR用に改造し、モデルの構造的なボトルネックを解消することを目指します。
さいごに
TrivialAugmentの導入により、ベースラインから大きな一歩を踏み出すことができました。今後はエポック数の増加やTTA、そしてモデル構造の最適化といった高次の施策に取り組んでいきます。
最後まで見てくださってありがとうございました!