7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorchでの分散学習時にはDistributedSamplerを指定することを忘れない!

Posted at

PyTorch DDPでのマルチプロセス分散学習時のデータセットの指定方法について誤解していたので動作挙動を示したメモ。

TL;DR

分散学習時にDataLoaderを作成するとき、samplerオプションにDistributedSamplerを指定しないとプロセス間でミニバッチサンプルを分割してくれないので注意(同じデータが各プロセスに送られる)

挙動の確認

データセットの用意

まず初めにデータセットを(適当に)作る。0から99までの整数をintとfloatで返すデータセットを作る。学習データとラベルデータの対を返すイメージ(二つ返ればなんでもよい)

train_dataset = torch.utils.data.TensorDataset(
                        torch.from_numpy(np.arange(100)),
                        torch.from_numpy(np.arange(100.))
                        )

データローダの作成(非分散学習)

バッチサイズを5に指定してDataLoaderを作成する。

通常のシングルノード(プロセス)の学習の場合、作成したデータセットをDataLoader作成時に指定しておけばOK。挙動を追うためにshuffleオプションをFalseにしておく。

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=5, 
                                   shuffle=False, 
                                             )

データセットからバッチを返す処理を確認。

for ix, (data, label) in enumerate(train_loader):
    print(ix, data, label)

実行結果は

0 tensor([0, 1, 2, 3, 4]) tensor([0., 1., 2., 3., 4.], dtype=torch.float64)
1 tensor([5, 6, 7, 8, 9]) tensor([5., 6., 7., 8., 9.], dtype=torch.float64)
2 tensor([10, 11, 12, 13, 14]) tensor([10., 11., 12., 13., 14.], dtype=torch.float64)
3 tensor([15, 16, 17, 18, 19]) tensor([15., 16., 17., 18., 19.], dtype=torch.float64)
4 tensor([20, 21, 22, 23, 24]) tensor([20., 21., 22., 23., 24.], dtype=torch.float64)
5 tensor([25, 26, 27, 28, 29]) tensor([25., 26., 27., 28., 29.], dtype=torch.float64)
...
(略)

各バッチごとに整数のtensor(data)と小数のtensor(label)が帰ってくるのが確認できる。

分散学習時のDataLoader

DataLoaderの作成(失敗例)

上のようにして作ったDataLoaderをプロセス数2の分散学習にそのまま使ってみる。



#### training
    print("start training")
    for idx, (data, label) in enumerate(trainloader):
        print(idx, data, label)
        
    print("completed!!!")

出力結果は

z03byxva8i-algo-2-w2c8j | start training
12a15d5fuc-algo-1-w2c8j | start training
12a15d5fuc-algo-1-w2c8j | 0 tensor([0, 1, 2, 3, 4]) tensor([0., 1., 2., 3., 4.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 0 tensor([0, 1, 2, 3, 4]) tensor([0., 1., 2., 3., 4.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 1 tensor([5, 6, 7, 8, 9]) tensor([5., 6., 7., 8., 9.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 1 tensor([5, 6, 7, 8, 9]) tensor([5., 6., 7., 8., 9.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 2 tensor([10, 11, 12, 13, 14]) tensor([10., 11., 12., 13., 14.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 2 tensor([10, 11, 12, 13, 14]) tensor([10., 11., 12., 13., 14.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 3 tensor([15, 16, 17, 18, 19]) tensor([15., 16., 17., 18., 19.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 3 tensor([15, 16, 17, 18, 19]) tensor([15., 16., 17., 18., 19.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 4 tensor([20, 21, 22, 23, 24]) tensor([20., 21., 22., 23., 24.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 4 tensor([20, 21, 22, 23, 24]) tensor([20., 21., 22., 23., 24.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 5 tensor([25, 26, 27, 28, 29]) tensor([25., 26., 27., 28., 29.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 5 tensor([25, 26, 27, 28, 29]) tensor([25., 26., 27., 28., 29.], dtype=torch.float64)
...
z03byxva8i-algo-2-w2c8j | 17 tensor([85, 86, 87, 88, 89]) tensor([85., 86., 87., 88., 89.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 17 tensor([85, 86, 87, 88, 89]) tensor([85., 86., 87., 88., 89.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 18 tensor([90, 91, 92, 93, 94]) tensor([90., 91., 92., 93., 94.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | 18 tensor([90, 91, 92, 93, 94]) tensor([90., 91., 92., 93., 94.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | 19 tensor([95, 96, 97, 98, 99]) tensor([95., 96., 97., 98., 99.], dtype=torch.float64)
z03byxva8i-algo-2-w2c8j | completed!!!
12a15d5fuc-algo-1-w2c8j | 19 tensor([95, 96, 97, 98, 99]) tensor([95., 96., 97., 98., 99.], dtype=torch.float64)
12a15d5fuc-algo-1-w2c8j | completed!!!

パイプ|の左側がプロセス名。この場合z03...12a...二つのプロセスが走っている。二つのプロセスにまったく同じバッチがフルサンプル分渡っていることがわかる。

DistributedSampler

サンプルをプロセスごとにうまく配分して送るためにはDistributedSamplerを使う。

import torch.distributed as dist

train_sampler = torch.utils.data.distributed.DistributedSampler(
                        train_dataset, 
                        num_replicas=dist.get_world_size(), 
                        rank=dist.get_rank(),
                        shuffle=False,

num_replicasはプロセス数、rankは各プロセスを識別するid。
torch.distributedを使って取得する。

実験のためここでもshuffleをFalseにしておく。

DataLoaderの作成(成功例)

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=5, 
                    shuffle=train_sampler is None,  
                    sampler=train_sampler,
                    )

分散学習時はsamplerオプションに先ほど作成したDistributedDataSamplerを指定しておく。

DataLoaderにsamplerを指定した場合、shuffleの有無はDataLoaderでなくSampler側で指定する。そのためDataLoader側でshuffleをTrueに指定した場合エラーがでる。これを防ぐために分散学習のon/off(正確にはDistributedSamplerの有/無(None))に応じてshuffleオプションを切り替えるような記述が多い

実行結果は

nn1i3nw2tt-algo-2-u6txj | start training
le9iijpjis-algo-1-u6txj | start training
nn1i3nw2tt-algo-2-u6txj | 0 tensor([1, 3, 5, 7, 9]) tensor([1., 3., 5., 7., 9.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 0 tensor([0, 2, 4, 6, 8]) tensor([0., 2., 4., 6., 8.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 1 tensor([11, 13, 15, 17, 19]) tensor([11., 13., 15., 17., 19.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 1 tensor([10, 12, 14, 16, 18]) tensor([10., 12., 14., 16., 18.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 2 tensor([20, 22, 24, 26, 28]) tensor([20., 22., 24., 26., 28.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 2 tensor([21, 23, 25, 27, 29]) tensor([21., 23., 25., 27., 29.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 3 tensor([30, 32, 34, 36, 38]) tensor([30., 32., 34., 36., 38.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 3 tensor([31, 33, 35, 37, 39]) tensor([31., 33., 35., 37., 39.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 4 tensor([41, 43, 45, 47, 49]) tensor([41., 43., 45., 47., 49.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 4 tensor([40, 42, 44, 46, 48]) tensor([40., 42., 44., 46., 48.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 5 tensor([51, 53, 55, 57, 59]) tensor([51., 53., 55., 57., 59.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 5 tensor([50, 52, 54, 56, 58]) tensor([50., 52., 54., 56., 58.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 6 tensor([61, 63, 65, 67, 69]) tensor([61., 63., 65., 67., 69.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 6 tensor([60, 62, 64, 66, 68]) tensor([60., 62., 64., 66., 68.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 7 tensor([71, 73, 75, 77, 79]) tensor([71., 73., 75., 77., 79.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 7 tensor([70, 72, 74, 76, 78]) tensor([70., 72., 74., 76., 78.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 8 tensor([81, 83, 85, 87, 89]) tensor([81., 83., 85., 87., 89.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 8 tensor([80, 82, 84, 86, 88]) tensor([80., 82., 84., 86., 88.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | 9 tensor([91, 93, 95, 97, 99]) tensor([91., 93., 95., 97., 99.], dtype=torch.float64)
nn1i3nw2tt-algo-2-u6txj | completed!!!
le9iijpjis-algo-1-u6txj | 9 tensor([90, 92, 94, 96, 98]) tensor([90., 92., 94., 96., 98.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | completed!!!

各プロセスで5サンプル*10バッチのミニバッチが渡され、全体として100個のサンプルが渡されていることがわかる。

また、shuffle=Falseの時の挙動は、全体のサンプルを一つずつ順番に割り当てていることがわかる。

nn1i3nw2tt-algo-2-u6txj | 0 tensor([1, 3, 5, 7, 9]) tensor([1., 3., 5., 7., 9.], dtype=torch.float64)
le9iijpjis-algo-1-u6txj | 0 tensor([0, 2, 4, 6, 8]) tensor([0., 2., 4., 6., 8.], dtype=torch.float64)

0-49をプロセス1、50-99をプロセス2に振り分け、という挙動はしない様子。

最終バッチの処理

バッチサイズとプロセス数の組み合わせによっては、最後のミニバッチが中途半端になることがある。

例えばサンプルサイズ100、ミニバッチサイズ5、分散学習のプロセス数3の場合。

1回のバッチで3*5= 15サンプル消費される。6回のバッチで90サンプル消費され、7回目のバッチの場合は残りの10サンプルを3プロセスに配給することになる。

この場合、各3プロセスには3ないし4つのバッチが配給されることが予想される。実際の挙動をみてみる(抜粋)。

c4uvie3yv8-algo-1-ojhpc | 6 tensor([90, 93, 96, 99]) tensor([90., 93., 96., 99.], dtype=torch.float64)
mwd4bv7z02-algo-2-ojhpc | 6 tensor([91, 94, 97,  0]) tensor([91., 94., 97.,  0.], dtype=torch.float64)
fqvsn1ygjt-algo-3-ojhpc | 6 tensor([92, 95, 98,  1]) tensor([92., 95., 98.,  1.], dtype=torch.float64)

ミニバッチサイズは5だが、バッチ最終端の場合は4つに短縮される。ただし、各プロセスに4つのサンプルを配給するため、残り10個に対して12個必要で2個たりない。その分は、すでに(過去のバッチで)使用したサンプル(この場合サンプル値0と1)を拝借している。

そのため、バッチ数とプロセス数を変えたときに、1エポックあたりのミニバッチが変わり実験の条件が微妙に変わってしまうことがある。

これがイヤな場合はDistributedSampler作成時にdrop_lastオプションをTrueにする。

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, 
                               num_replicas=dist.get_world_size(), 
                               rank=dist.get_rank(),
                               shuffle=False,
                               drop_last=True
                               )
Process 1 | 6 tensor([90, 93, 96]) tensor([90., 93., 96.], dtype=torch.float64)
Process 2 | 6 tensor([91, 94, 97]) tensor([91., 94., 97.], dtype=torch.float64)
Process 3 | 6 tensor([92, 95, 98]) tensor([92., 95., 98.], dtype=torch.float64)

最後のバッチで、サンプルの再利用がないことがわかる。ただし、全てのサンプルサイズを揃えるため、逆に使っていないサンプル(99)が出てくる。

ちなみにDistributedDataLoaderのdrop_lastオプションはver. 1.8.0以降でないと実装されていないので注意

DataLoaderのdrop_lastとDistributedSamplerのdrop_lastの違い

ところでDataLoaderにもdrop_lastオプションがある。

DataLoaderのdeop_lastオプションは最後の中途半端なバッチをバッサリとカットしてしまう(使用しない)ため、挙動が違うことに注意

Process 1 | 5 tensor([75, 78, 81, 84, 87]) tensor([75., 78., 81., 84., 87.], dtype=torch.float64)
completed!!!
Process 2 | 5 tensor([76, 79, 82, 85, 88]) tensor([76., 79., 82., 85., 88.], dtype=torch.float64)
completed!!!
Process 3 | 5 tensor([77, 80, 83, 86, 89]) tensor([77., 80., 83., 86., 89.], dtype=torch.float64)
completed!!!

最後に

とにかく分散学習の時はDistributedSamplerを指定することを忘れないようにしましょう。

7
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?