結論
timm モデルを用いて SyncBatchNorm を実現する場合は以下のユーティリティを利用する。
model = timm.layers.norm_act.convert_sync_batchnorm(model)
経緯
複数のGPUで学習を行う際に1GPUあたりのバッチサイズが小さいと学習が安定しないことがある。この対策として DistributedDataParallel であれば BatchNorm統計量を GPU間で同期させる SyncBatchNorm がレイヤーが存在する。
利用方法としてはモデルの BatchNormXd を SyncBatchNormXd に変更して、保持している統計量などのパラメータをコピーすればよい。また、torchvisionなどのモデルであれば下記のユーティリティが利用できる。
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
この関数は model の中にある BatchNormXd のインスタンスを見つけると SyncBatchNormXd に置き換えるという機能を提供している。
問題
timm モデルにおいて torch の提供する convert_sync_batchnorm でモデルを変換すると学習ができなくなるという問題が発生した。本来 SyncBatchNorm が余計なことをしない Single-GPU の環境であっても生じたため SyncBatchNorm に変更する過程で何らかの問題が生じていることが判る。
原因
timm のモデルは BatchNormAct2d という形で BatchNorm2d を継承して Activation を統合したレイヤーを用いるようになっている。
先のユーティリティは model 内の BatchNorm のインスタンスを自動的に SyncBatchNorm に置き換えるものであり、isinstanceで検索されていることから継承したレイヤーも対象となる。timm のモデルの BatchNormActXd は BatchNormXd を継承したレイヤーとして作られているため、torch のユーティリティを使うと SyncBatchNormXd に置き換えられる。
結果として存在した Activation レイヤーが消されてしまい想定したモデルと異なるモデルになり、計算結果が破壊される。これは issue でも取り上げられていた。
timm.layers.norm_act.convert_sync_batchnorm(model) を使うと BatchNormAct2d が SyncBatchNormAct2d として変換されるので、こちらを使う必要がある。
この問題は本質的には timm 固有の問題ではなく
BatchNorm を継承したレイヤーを自分で定義した際にも起きうるため注意が必要。
PyTorch Lightningについて
lightning においても layer_sync.py にて同じく torch のユーティリティを用いているので timm モデルを使用する場合は自前で変換が必要となる。
参考
- timm - GitHub
- PyTorch Documentation
- pytorch-lightning - GitHub