ResNet50実装記録 Vol.2:W&Bで実験管理をはじめる
はじめに
こんにちは、DL(ディープラーニング)を勉強している大学1年生です。
2025年の4月からプログラミングを本格的に勉強し始め、GCI(東大松尾研が提供するデータサイエンス入門講座)を修了し、現在DL基礎講座を受講しています。
今回は、DL基礎講座の最終課題に向けた学習のアウトプットとして、ResNet50を用いた画像分類に挑戦します。
この記事は、自分自身のポートフォリオとして ResNet50の実装と、その精度改善施策 を行うプロセスを記録し、共有することを目的としています。
※この記事は実装と改善のプロセスに焦点を当てており、ResNetの理論的な詳細には深く立ち入りません。理論的背景については、適切な参考文献をご参照ください。
前回の記事ではゼロからResNet50を構築しました。
今回は実験結果を効率的に保存・可視化するために、W&B (Weights & Biases) を導入していきます。
なぜW&Bなのか?
実験管理ツールを選定するにあたり、Geminiに相談して比較表を作成しました。
| 項目 | W&B (Weights & Biases) | MLflow | CSV/JSON |
|---|---|---|---|
| 提供形態 | SaaS(クラウド) | オープンソース(セルフホスト) | 手法(ファイルI/O) |
| 初期費用 | 無料 (Personalプラン) | 無料 (ソフトウェア) | 無料 |
| 運用費用 | 無料 (Personalプラン) | インフラ費 or マネージドサービス費 | なし |
| セットアップ | 非常に簡単 (pip + login) | 手間がかかる (サーバー構築) | 不要 |
| 可視化 | ◎ (非常にリッチ・自動) | ◯ (基本的なグラフ) | × (自力で実装) |
| 実験比較 | ◎ (Web UIで簡単) | ◯ (Web UIで可能) | × (非常に困難) |
もともとはローカル環境でCSV等で管理する予定でしたが、セットアップの手軽さと可視化機能の強力さに惹かれ、今回は W&B (Weights & Biases) を採用することにしました。
実装
Google Colab環境での実装手順です。
1. ライブラリのインストール
まずはW&Bのライブラリをインストールします。
!pip install wandb -q
2. W&B にログイン
インストールが完了したらログインします。以下のコードを実行するとAPIキーの入力を求められるので、W&Bの公式サイトで取得したキーを入力します。
import wandb
wandb.login()
3. プロジェクトの初期化 (wandb.init)
実験を開始する際に wandb.init を呼び出します。 ここでハイパーパラメータ(config)を辞書形式で渡しておくことで、後からWeb画面上で条件ごとの精度比較が簡単にできるようになります。
run = wandb.init(
project="classification-of-cifar-10-by-ResNet50",
config={
"learning_rate": 0.01,
"epochs": 10,
"batch_size": 128,
"momentum": 0.9,
"architecture": "ResNet50",
# 前処理の設定なども記録しておくと便利
"preprocessing": "ZCA whitening",
"use_whitening": False,
}
)
4. 学習・検証ログの送信 (wandb.log)
学習ループや検証(Test)のタイミングで wandb.log を使って指標を送信します。これだけでWebダッシュボードにグラフが自動生成されます。
# 検証フェーズなどで指標を計算した後に実行
wandb.log({
"test/accuracy": acc,
"test/precision": precision,
"test/recall": recall,
"test/f1": f1,
"test/roc_auc": roc_auc
})
おわりに
今回は、実験管理ツールとして W&B を導入し、学習ログをクラウド上で保存・可視化する環境を整えました。
これにより、学習率やバッチサイズなどのハイパーパラメータを変更した際に、精度がどのように変化するかを直感的に比較できるようになります。地味な作業ですが、試行錯誤を繰り返すDeep Learningの実験において、こうした「足場固め」は非常に重要だと感じています。
次回はいよいよ、構築したResNet50モデルとこの実験環境を使って ベースラインモデルの学習 を行います。
その後、記事の冒頭で挙げた「正規化」「乱数固定」「転移学習」などの改善施策を順次適用し、CIFAR-10の分類精度がどこまで向上するのかを検証していきます。
最後まで読んでいただきありがとうございました!