記事概要
深層学習 (機械学習) がうまくいかない時あるある備忘録です。
本記事では深層学習がうまくいかない原因を以下の2通りに場合わけして考えます。
- モデルパラメータがうまいこと獲得されないケース
- モデルパラメータは獲得されるけどそれ以外の部分 (前処理や後処理) でミスっているケース
ツッコミ歓迎です。そのうち加筆するかもです。
モデルパラメータがうまいこと獲得されないケース
ハイパーパラメータ編
学習率が大きすぎる/小さすぎる
深層学習において学習率の設定はめちゃくちゃ重要です。200エポックかけて到達したロスに、学習率を変えたら10イテレーションで到達した❗️みたいなことがザラにおきます (おきました) 。論文の再現をしたい時には著者の使っている学習率を使うのが丸いです。僕の場合は次のような順序で探索してそれっぽいパラメタを見つけることが多いです。
- 著者のリポジトリで使われているパラメタ
- 論文に書いてあるパラメタ
- その他再現実装のパラメタ
- 類似モデルで使われているパラメタ
- 気合で神パラメタを引く (grid search / optuna)
ちなみに学習率とバッチサイズには関係があり、(例えば https://arxiv.org/pdf/1711.00489.pdf など) バッチサイズを n 倍した場合、学習率も n 倍する必要があります。論文のバッチサイズは GPU クラスタを使用することを前提にしていることも多々あるので、GPU 1台で細々と学習するときはバッチサイズに加えて学習率を調整することをお忘れなく。
データ編
GT の生成をミスる
クラス分類程度なら GT (ground truth / 正解データ) を作る必要はないですが、物体検出などやや入り組んだことをやろうとすると GT の作成をミスることが往々にしてあります。例えば CenterNet (https://arxiv.org/abs/1904.07850) だと、(中心座標、幅、オフセット) のテンソルを GT として作ってやる必要があるのですが、まあ盛大なバグを埋め込むことに成功しました。生成した GT はちゃんと目視確認しましょう。
前処理や後処理でミスっているケース (モデルパラメータは良い)
データ編
検証用データをオーグメントしてしまう
物体検出など位置が重要なアルゴリズムはクリティカルになることがあります。後処理コードがオーグメントしない前提で書かれているにもかかわらず、入力がオーグメントされている場合などですね。まあまあ酷いことが起きます。後処理がオーグメントを考慮できるようにしてあるにしても、一般にオーグメントするとモデルの性能は落ちますので正確な評価のためにも検証用 / テスト用データはオーグメントしないようにしましょう。
RGB と BGR を間違える
あまりないケースではあるのですが、PIL.Image.open
と cv2.imread
はチャンネルが逆転しているので不幸な事故が起きることがあります。前者は 'RGB'
、 後者は 'BGR'
です。訓練時は cv2
で読んでテスト時は PIL
で読むみたいな実装になっている場合問題になります。モデルに流す前に transpose する必要がありますね。平均分散を使った標準化処理をしている場合も問題がおきます。PCA color augmentation とかでも問題が起きるかな。
モデル編
最後の活性化関数を忘れる
僕は普段 pytorch で実装しているのですが、テスト時に最後の活性化関数をつけ忘れることが時々あります。例えば物体認識だと torch.nn.CrossEntropyLoss
を使うかと思うのですが、これを使ってるとテスト時に nn.Softmax
をつけ忘れたりすることがあります。これに関しては出力のレンジを調べれば簡単にチェックできるかなという感じです。
後処理編
出力を元の空間の縮尺に戻し忘れる
物体検出なんかでやりがちなのですが、(3, 512, 512)
の入力が (80, 128, 128)
の出力で出てくるみたいな、空間方向に縮尺が変化するモデルを使っている時に起こります。これをやらかすと、物体検出だと GT との IoU が著しく小さくなるのでいくらモデルパラメータが正しく学習できていても mAP が地を這う現象が起きます。学習データの GT を元の空間に戻してみてきちんと元の値に戻るかをチェックしてみるといいと思います。
評価指標編
そもそも評価指標が間違ってる
「学習してるのに評価値 (accuracy / mAP / MSE など) が全然上がらんぞ❗️」ということは往々にしてあります。この時ありがちなのがそもそも評価指標の実装を間違えているケース。間違え方は多種多様なので具体的にどうしろとかはないのですが、オープンソースの評価指標がある場合はそれを使った方が丸いです。実装する時も評価指標に合わせて設計するといいのかなと思います。オープンソースの評価指標はいろいろありますが、github ならスターの数が3桁以上ならまず信用していいのではないかなと思います。僕が指標探す時の優先順位は次のような感じです。
- データセット公式
- github のリポジトリ
- 気合で実装
3 は本当に最終手段でなるべくやらない方がいいです。あまりにも本質的でないところに時間が取られ過ぎてしまうので…… 実装の練習をしたいとかでない限り回避した方がよい択だと思います。「どうしても自分のリポジトリで完結したいんや❗️」という人はすでにあるリポジトリをフォークしましょう。