LoginSignup
1
0

timm + SyncBatchNorm

Last updated at Posted at 2023-12-06

結論

timm モデルを用いて SyncBatchNorm を実現する場合は以下のユーティリティを利用する。

model = timm.layers.norm_act.convert_sync_batchnorm(model)

経緯

複数のGPUで学習を行う際に1GPUあたりのバッチサイズが小さいと学習が安定しないことがある。この対策として DistributedDataParallel であれば BatchNorm統計量を GPU間で同期させる SyncBatchNorm がレイヤーが存在する。

利用方法としてはモデルの BatchNormXd を SyncBatchNormXd に変更して、保持している統計量などのパラメータをコピーすればよい。また、torchvisionなどのモデルであれば下記のユーティリティが利用できる。

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

この関数は model の中にある BatchNormXd のインスタンスを見つけると SyncBatchNormXd に置き換えるという機能を提供している。

問題

timm モデルにおいて torch の提供する convert_sync_batchnorm でモデルを変換すると学習ができなくなるという問題が発生した。本来 SyncBatchNorm が余計なことをしない Single-GPU の環境であっても生じたため SyncBatchNorm に変更する過程で何らかの問題が生じていることが判る。

原因

timm のモデルは BatchNormAct2d という形で BatchNorm2d を継承して Activation を統合したレイヤーを用いるようになっている。

先のユーティリティは model 内の BatchNorm のインスタンスを自動的に SyncBatchNorm に置き換えるものであり、isinstanceで検索されていることから継承したレイヤーも対象となる。timm のモデルの BatchNormActXd は BatchNormXd を継承したレイヤーとして作られているため、torch のユーティリティを使うと SyncBatchNormXd に置き換えられる。
結果として存在した Activation レイヤーが消されてしまい想定したモデルと異なるモデルになり、計算結果が破壊される。これは issue でも取り上げられていた。

timm.layers.norm_act.convert_sync_batchnorm(model) を使うと BatchNormAct2d が SyncBatchNormAct2d として変換されるので、こちらを使う必要がある。

この問題は本質的には timm 固有の問題ではなく
BatchNorm を継承したレイヤーを自分で定義した際にも起きうるため注意が必要。

PyTorch Lightningについて

lightning においても layer_sync.py にて同じく torch のユーティリティを用いているので timm モデルを使用する場合は自前で変換が必要となる。

参考

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0